mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix: restore all removed bundled skills + fix skills sync system
- Restored 21 skills removed in commits757d012and740dd92: accelerate, audiocraft, code-review, faiss, flash-attention, gguf, grpo-rl-training, guidance, llava, nemo-curator, obliteratus, peft, pytorch-fsdp, pytorch-lightning, simpo, slime, stable-diffusion, tensorrt-llm, torchtitan, trl-fine-tuning, whisper - Rewrote sync_skills() with proper update semantics: * New skills (not in manifest): copied to user dir * Existing skills (in manifest + on disk): updated via hash comparison * User-deleted skills (in manifest, not on disk): respected, not re-added * Stale manifest entries (removed from bundled): cleaned from manifest - Added sync_skills() to CLI startup (cmd_chat) and gateway startup (start_gateway) — previously only ran during 'hermes update' - Updated cmd_update output to show new/updated/cleaned counts - Rewrote tests: 20 tests covering manifest CRUD, dir hashing, fresh install, user deletion respect, update detection, stale cleanup, and name collision handling 75 bundled skills total. 2002 tests pass.
This commit is contained in:
parent
68fbae5692
commit
ab0f4126cf
74 changed files with 27881 additions and 44 deletions
335
skills/mlops/accelerate/SKILL.md
Normal file
335
skills/mlops/accelerate/SKILL.md
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
---
|
||||
name: huggingface-accelerate
|
||||
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [accelerate, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
|
||||
|
||||
---
|
||||
|
||||
# HuggingFace Accelerate - Unified Distributed Training
|
||||
|
||||
## Quick start
|
||||
|
||||
Accelerate simplifies distributed training to 4 lines of code.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
**Convert PyTorch script** (4 lines):
|
||||
```python
|
||||
import torch
|
||||
+ from accelerate import Accelerator
|
||||
|
||||
+ accelerator = Accelerator()
|
||||
|
||||
model = torch.nn.Transformer()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset)
|
||||
|
||||
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
- loss.backward()
|
||||
+ accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Run** (single command):
|
||||
```bash
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: From single GPU to multi-GPU
|
||||
|
||||
**Original script**:
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
|
||||
model = torch.nn.Linear(10, 2).to('cuda')
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
batch = batch.to('cuda')
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**With Accelerate** (4 lines added):
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
from accelerate import Accelerator # +1
|
||||
|
||||
accelerator = Accelerator() # +2
|
||||
|
||||
model = torch.nn.Linear(10, 2)
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
# No .to('cuda') needed - automatic!
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
accelerator.backward(loss) # +4
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Configure** (interactive):
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
- Which machine? (single/multi GPU/TPU/CPU)
|
||||
- How many machines? (1)
|
||||
- Mixed precision? (no/fp16/bf16/fp8)
|
||||
- DeepSpeed? (no/yes)
|
||||
|
||||
**Launch** (works on any setup):
|
||||
```bash
|
||||
# Single GPU
|
||||
accelerate launch train.py
|
||||
|
||||
# Multi-GPU (8 GPUs)
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
train.py
|
||||
```
|
||||
|
||||
### Workflow 2: Mixed precision training
|
||||
|
||||
**Enable FP16/BF16**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# FP16 (with gradient scaling)
|
||||
accelerator = Accelerator(mixed_precision='fp16')
|
||||
|
||||
# BF16 (no scaling, more stable)
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# FP8 (H100+)
|
||||
accelerator = Accelerator(mixed_precision='fp8')
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
# Everything else is automatic!
|
||||
for batch in dataloader:
|
||||
with accelerator.autocast(): # Optional, done automatically
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
```
|
||||
|
||||
### Workflow 3: DeepSpeed ZeRO integration
|
||||
|
||||
**Enable DeepSpeed ZeRO-2**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
deepspeed_plugin={
|
||||
"zero_stage": 2, # ZeRO-2
|
||||
"offload_optimizer": False,
|
||||
"gradient_accumulation_steps": 4
|
||||
}
|
||||
)
|
||||
|
||||
# Same code as before!
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: DeepSpeed → ZeRO-2
|
||||
```
|
||||
|
||||
**deepspeed_config.json**:
|
||||
```json
|
||||
{
|
||||
"fp16": {"enabled": false},
|
||||
"bf16": {"enabled": true},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {"device": "cpu"},
|
||||
"allgather_bucket_size": 5e8,
|
||||
"reduce_bucket_size": 5e8
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
accelerate launch --config_file deepspeed_config.json train.py
|
||||
```
|
||||
|
||||
### Workflow 4: FSDP (Fully Sharded Data Parallel)
|
||||
|
||||
**Enable FSDP**:
|
||||
```python
|
||||
from accelerate import Accelerator, FullyShardedDataParallelPlugin
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
|
||||
auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
|
||||
cpu_offload=False
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
fsdp_plugin=fsdp_plugin
|
||||
)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: FSDP → Full Shard → No CPU Offload
|
||||
```
|
||||
|
||||
### Workflow 5: Gradient accumulation
|
||||
|
||||
**Accumulate gradients**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(gradient_accumulation_steps=4)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Accelerate when**:
|
||||
- Want simplest distributed training
|
||||
- Need single script for any hardware
|
||||
- Use HuggingFace ecosystem
|
||||
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
|
||||
- Need quick prototyping
|
||||
|
||||
**Key advantages**:
|
||||
- **4 lines**: Minimal code changes
|
||||
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
|
||||
- **Automatic**: Device placement, mixed precision, sharding
|
||||
- **Interactive config**: No manual launcher setup
|
||||
- **Single launch**: Works everywhere
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **PyTorch Lightning**: Need callbacks, high-level abstractions
|
||||
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
|
||||
- **DeepSpeed**: Direct API control, advanced features
|
||||
- **Raw DDP**: Maximum control, minimal abstraction
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Wrong device placement**
|
||||
|
||||
Don't manually move to device:
|
||||
```python
|
||||
# WRONG
|
||||
batch = batch.to('cuda')
|
||||
|
||||
# CORRECT
|
||||
# Accelerate handles it automatically after prepare()
|
||||
```
|
||||
|
||||
**Issue: Gradient accumulation not working**
|
||||
|
||||
Use context manager:
|
||||
```python
|
||||
# CORRECT
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Issue: Checkpointing in distributed**
|
||||
|
||||
Use accelerator methods:
|
||||
```python
|
||||
# Save only on main process
|
||||
if accelerator.is_main_process:
|
||||
accelerator.save_state('checkpoint/')
|
||||
|
||||
# Load on all processes
|
||||
accelerator.load_state('checkpoint/')
|
||||
```
|
||||
|
||||
**Issue: Different results with FSDP**
|
||||
|
||||
Ensure same random seed:
|
||||
```python
|
||||
from accelerate.utils import set_seed
|
||||
set_seed(42)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.
|
||||
|
||||
**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.
|
||||
|
||||
**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **CPU**: Works (slow)
|
||||
- **Single GPU**: Works
|
||||
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
|
||||
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
|
||||
- **TPU**: Supported
|
||||
- **Apple MPS**: Supported
|
||||
|
||||
**Launcher requirements**:
|
||||
- **DDP**: `torch.distributed.run` (built-in)
|
||||
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
|
||||
- **FSDP**: PyTorch 1.12+ (built-in)
|
||||
- **Megatron**: Custom setup
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://huggingface.co/docs/accelerate
|
||||
- GitHub: https://github.com/huggingface/accelerate
|
||||
- Version: 1.11.0+
|
||||
- Tutorial: "Accelerate your scripts"
|
||||
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
|
||||
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries
|
||||
|
||||
|
||||
|
||||
453
skills/mlops/accelerate/references/custom-plugins.md
Normal file
453
skills/mlops/accelerate/references/custom-plugins.md
Normal file
|
|
@ -0,0 +1,453 @@
|
|||
# Custom Plugins for Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
|
||||
|
||||
## Plugin Architecture
|
||||
|
||||
### Base Plugin Structure
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
"""Custom training plugin."""
|
||||
|
||||
# Plugin configuration
|
||||
param1: int = 1
|
||||
param2: str = "default"
|
||||
|
||||
def __post_init__(self):
|
||||
# Validation logic
|
||||
if self.param1 < 1:
|
||||
raise ValueError("param1 must be >= 1")
|
||||
```
|
||||
|
||||
### Using Custom Plugin
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# Create plugin
|
||||
custom_plugin = CustomPlugin(param1=4, param2="value")
|
||||
|
||||
# Pass to Accelerator
|
||||
accelerator = Accelerator(
|
||||
custom_plugin=custom_plugin # Not a real parameter, example only
|
||||
)
|
||||
```
|
||||
|
||||
## Built-In Plugin Examples
|
||||
|
||||
### 1. GradScalerKwargs (FP16 Configuration)
|
||||
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
# Configure gradient scaler for FP16
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16, # Initial loss scale
|
||||
growth_factor=2.0, # Scale growth rate
|
||||
backoff_factor=0.5, # Scale backoff rate
|
||||
growth_interval=2000, # Steps between scale increases
|
||||
enabled=True # Enable scaler
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Fine-tune FP16 gradient scaling behavior
|
||||
|
||||
### 2. DistributedDataParallelKwargs
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
# Configure DDP behavior
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Gradient bucketing size
|
||||
find_unused_parameters=False, # Find unused params (slower)
|
||||
check_reduction=False, # Check gradient reduction
|
||||
gradient_as_bucket_view=True, # Memory optimization
|
||||
static_graph=False # Static computation graph
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
kwargs_handlers=[ddp_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Optimize DDP performance for specific models
|
||||
|
||||
### 3. FP8RecipeKwargs (H100 FP8)
|
||||
|
||||
```python
|
||||
from accelerate.utils import FP8RecipeKwargs
|
||||
|
||||
# Configure FP8 training (H100)
|
||||
fp8_recipe = FP8RecipeKwargs(
|
||||
backend="te", # TransformerEngine backend
|
||||
margin=0, # Scaling margin
|
||||
interval=1, # Scaling interval
|
||||
fp8_format="HYBRID", # E4M3 + E5M2 hybrid
|
||||
amax_history_len=1024, # AMAX history length
|
||||
amax_compute_algo="max" # AMAX computation algorithm
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp8',
|
||||
kwargs_handlers=[fp8_recipe]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Ultra-fast training on H100 GPUs
|
||||
|
||||
## Custom DeepSpeed Configuration
|
||||
|
||||
### ZeRO-3 with CPU Offload
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
# Custom DeepSpeed config
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3, # ZeRO-3
|
||||
offload_optimizer_device="cpu", # CPU offload optimizer
|
||||
offload_param_device="cpu", # CPU offload parameters
|
||||
zero3_init_flag=True, # ZeRO-3 initialization
|
||||
zero3_save_16bit_model=True, # Save FP16 weights
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### ZeRO-2 with NVMe Offload
|
||||
|
||||
```python
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=2,
|
||||
offload_optimizer_device="nvme", # NVMe offload
|
||||
offload_param_device="nvme",
|
||||
nvme_path="/local_nvme", # NVMe mount path
|
||||
)
|
||||
```
|
||||
|
||||
### Custom JSON Config
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
# Load custom DeepSpeed config
|
||||
with open('deepspeed_config.json', 'r') as f:
|
||||
ds_config = json.load(f)
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
|
||||
|
||||
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
|
||||
```
|
||||
|
||||
**Example config** (`deepspeed_config.json`):
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"stage3_prefetch_bucket_size": 5e8,
|
||||
"stage3_param_persistence_threshold": 1e6,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"steps_per_print": 100,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
```
|
||||
|
||||
## Custom FSDP Configuration
|
||||
|
||||
### FSDP with Custom Auto-Wrap Policy
|
||||
|
||||
```python
|
||||
from accelerate.utils import FullyShardedDataParallelPlugin
|
||||
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
|
||||
import functools
|
||||
|
||||
# Custom wrap policy (size-based)
|
||||
wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy,
|
||||
min_num_params=1e6 # Wrap layers with 1M+ params
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent
|
||||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy
|
||||
mixed_precision_policy=None, # Use Accelerator's mixed precision
|
||||
auto_wrap_policy=wrap_policy, # Custom wrapping
|
||||
cpu_offload=False,
|
||||
ignored_modules=None, # Modules to not wrap
|
||||
state_dict_type="FULL_STATE_DICT", # Save format
|
||||
optim_state_dict_config=None,
|
||||
limit_all_gathers=False,
|
||||
use_orig_params=True, # Use original param shapes
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
fsdp_plugin=fsdp_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### FSDP with Transformer Auto-Wrap
|
||||
|
||||
```python
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
|
||||
|
||||
# Wrap at transformer block level
|
||||
wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={GPT2Block} # Wrap GPT2Block layers
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
auto_wrap_policy=wrap_policy
|
||||
)
|
||||
```
|
||||
|
||||
## Creating Custom Training Strategy
|
||||
|
||||
### Example: Custom Gradient Accumulation
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
class CustomGradientAccumulation:
|
||||
def __init__(self, steps=4, adaptive=False):
|
||||
self.steps = steps
|
||||
self.adaptive = adaptive
|
||||
self.current_step = 0
|
||||
|
||||
def should_sync(self, loss):
|
||||
"""Decide whether to sync gradients."""
|
||||
self.current_step += 1
|
||||
|
||||
# Adaptive: sync on high loss
|
||||
if self.adaptive and loss > threshold:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
# Regular: sync every N steps
|
||||
if self.current_step >= self.steps:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# Usage
|
||||
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
|
||||
accelerator = Accelerator()
|
||||
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
# Scale loss
|
||||
loss = loss / custom_accum.steps
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Conditional sync
|
||||
if custom_accum.should_sync(loss.item()):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
### Example: Custom Mixed Precision
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class CustomMixedPrecision:
|
||||
"""Custom mixed precision with dynamic loss scaling."""
|
||||
|
||||
def __init__(self, init_scale=2**16, scale_window=2000):
|
||||
self.scaler = torch.cuda.amp.GradScaler(
|
||||
init_scale=init_scale,
|
||||
growth_interval=scale_window
|
||||
)
|
||||
self.scale_history = []
|
||||
|
||||
def scale_loss(self, loss):
|
||||
"""Scale loss for backward."""
|
||||
return self.scaler.scale(loss)
|
||||
|
||||
def unscale_and_clip(self, optimizer, max_norm=1.0):
|
||||
"""Unscale gradients and clip."""
|
||||
self.scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
optimizer.param_groups[0]['params'],
|
||||
max_norm
|
||||
)
|
||||
|
||||
def step(self, optimizer):
|
||||
"""Optimizer step with scaler update."""
|
||||
scale_before = self.scaler.get_scale()
|
||||
self.scaler.step(optimizer)
|
||||
self.scaler.update()
|
||||
scale_after = self.scaler.get_scale()
|
||||
|
||||
# Track scale changes
|
||||
if scale_before != scale_after:
|
||||
self.scale_history.append(scale_after)
|
||||
|
||||
# Usage
|
||||
custom_mp = CustomMixedPrecision()
|
||||
|
||||
for batch in dataloader:
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
loss = model(**batch).loss
|
||||
|
||||
scaled_loss = custom_mp.scale_loss(loss)
|
||||
scaled_loss.backward()
|
||||
|
||||
custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
|
||||
custom_mp.step(optimizer)
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Advanced: Custom Distributed Backend
|
||||
|
||||
### Custom AllReduce Strategy
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
|
||||
class CustomAllReduce:
|
||||
"""Custom all-reduce with compression."""
|
||||
|
||||
def __init__(self, compression_ratio=0.1):
|
||||
self.compression_ratio = compression_ratio
|
||||
|
||||
def compress_gradients(self, tensor):
|
||||
"""Top-k gradient compression."""
|
||||
k = int(tensor.numel() * self.compression_ratio)
|
||||
values, indices = torch.topk(tensor.abs().view(-1), k)
|
||||
return values, indices
|
||||
|
||||
def all_reduce_compressed(self, tensor):
|
||||
"""All-reduce with gradient compression."""
|
||||
# Compress
|
||||
values, indices = self.compress_gradients(tensor)
|
||||
|
||||
# All-reduce compressed gradients
|
||||
dist.all_reduce(values, op=dist.ReduceOp.SUM)
|
||||
|
||||
# Decompress
|
||||
tensor_compressed = torch.zeros_like(tensor).view(-1)
|
||||
tensor_compressed[indices] = values / dist.get_world_size()
|
||||
|
||||
return tensor_compressed.view_as(tensor)
|
||||
|
||||
# Usage in training loop
|
||||
custom_ar = CustomAllReduce(compression_ratio=0.1)
|
||||
|
||||
for batch in dataloader:
|
||||
loss = model(**batch).loss
|
||||
loss.backward()
|
||||
|
||||
# Custom all-reduce
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Plugin Best Practices
|
||||
|
||||
### 1. Validation in `__post_init__`
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
learning_rate: float = 1e-3
|
||||
warmup_steps: int = 1000
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate parameters
|
||||
if self.learning_rate <= 0:
|
||||
raise ValueError("learning_rate must be positive")
|
||||
if self.warmup_steps < 0:
|
||||
raise ValueError("warmup_steps must be non-negative")
|
||||
|
||||
# Compute derived values
|
||||
self.min_lr = self.learning_rate * 0.1
|
||||
```
|
||||
|
||||
### 2. Compatibility Checks
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
feature_enabled: bool = True
|
||||
|
||||
def is_compatible(self, accelerator):
|
||||
"""Check if plugin is compatible with accelerator config."""
|
||||
if self.feature_enabled and accelerator.mixed_precision == 'fp8':
|
||||
raise ValueError("Custom plugin not compatible with FP8")
|
||||
return True
|
||||
```
|
||||
|
||||
### 3. State Management
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
counter: int = 0
|
||||
history: list = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.history is None:
|
||||
self.history = []
|
||||
|
||||
def update_state(self, value):
|
||||
"""Update plugin state during training."""
|
||||
self.counter += 1
|
||||
self.history.append(value)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
|
||||
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
|
||||
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
|
||||
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu
|
||||
489
skills/mlops/accelerate/references/megatron-integration.md
Normal file
489
skills/mlops/accelerate/references/megatron-integration.md
Normal file
|
|
@ -0,0 +1,489 @@
|
|||
# Megatron Integration with Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.
|
||||
|
||||
**Megatron capabilities**:
|
||||
- **Tensor Parallelism (TP)**: Split layers across GPUs
|
||||
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
|
||||
- **Data Parallelism (DP)**: Replicate model across GPU groups
|
||||
- **Sequence Parallelism**: Split sequences for long contexts
|
||||
|
||||
## Setup
|
||||
|
||||
### Install Megatron-LM
|
||||
|
||||
```bash
|
||||
# Clone Megatron-LM repository
|
||||
git clone https://github.com/NVIDIA/Megatron-LM.git
|
||||
cd Megatron-LM
|
||||
pip install -e .
|
||||
|
||||
# Install Apex (NVIDIA optimizations)
|
||||
git clone https://github.com/NVIDIA/apex
|
||||
cd apex
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
|
||||
--config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
```
|
||||
|
||||
### Accelerate Configuration
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
```
|
||||
In which compute environment are you running?
|
||||
> This machine
|
||||
|
||||
Which type of machine are you using?
|
||||
> Multi-GPU
|
||||
|
||||
How many different machines will you use?
|
||||
> 1
|
||||
|
||||
Do you want to use DeepSpeed/FSDP?
|
||||
> No
|
||||
|
||||
Do you want to use Megatron-LM?
|
||||
> Yes
|
||||
|
||||
What is the Tensor Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Do you want to enable Sequence Parallelism?
|
||||
> No
|
||||
|
||||
What is the Pipeline Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
What is the Data Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
|
||||
> SELECTIVE
|
||||
|
||||
Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
|
||||
> SEQUENTIAL
|
||||
```
|
||||
|
||||
**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: MEGATRON_LM
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
megatron_lm_config:
|
||||
megatron_lm_gradient_clipping: 1.0
|
||||
megatron_lm_learning_rate_decay_iters: 320000
|
||||
megatron_lm_num_micro_batches: 1
|
||||
megatron_lm_pp_degree: 2
|
||||
megatron_lm_recompute_activations: true
|
||||
megatron_lm_sequence_parallelism: false
|
||||
megatron_lm_tp_degree: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
## Parallelism Strategies
|
||||
|
||||
### Tensor Parallelism (TP)
|
||||
|
||||
**Splits each transformer layer across GPUs**:
|
||||
|
||||
```python
|
||||
# Layer split across 2 GPUs
|
||||
# GPU 0: First half of attention heads
|
||||
# GPU 1: Second half of attention heads
|
||||
|
||||
# Each GPU computes partial outputs
|
||||
# All-reduce combines results
|
||||
```
|
||||
|
||||
**TP degree recommendations**:
|
||||
- **TP=1**: No tensor parallelism (single GPU per layer)
|
||||
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
|
||||
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
|
||||
- **TP=8**: 8 GPUs per layer (good for 70B+ models)
|
||||
|
||||
**Benefits**:
|
||||
- Reduces memory per GPU
|
||||
- All-reduce communication (fast)
|
||||
|
||||
**Drawbacks**:
|
||||
- Requires fast inter-GPU bandwidth (NVLink)
|
||||
- Communication overhead per layer
|
||||
|
||||
### Pipeline Parallelism (PP)
|
||||
|
||||
**Splits model depth across GPUs**:
|
||||
|
||||
```python
|
||||
# 12-layer model, PP=4
|
||||
# GPU 0: Layers 0-2
|
||||
# GPU 1: Layers 3-5
|
||||
# GPU 2: Layers 6-8
|
||||
# GPU 3: Layers 9-11
|
||||
```
|
||||
|
||||
**PP degree recommendations**:
|
||||
- **PP=1**: No pipeline parallelism
|
||||
- **PP=2**: 2 pipeline stages (good for 20-40B models)
|
||||
- **PP=4**: 4 pipeline stages (good for 70B+ models)
|
||||
- **PP=8**: 8 pipeline stages (good for 175B+ models)
|
||||
|
||||
**Benefits**:
|
||||
- Linear memory reduction (4× PP = 4× less memory)
|
||||
- Works across nodes (slower interconnect OK)
|
||||
|
||||
**Drawbacks**:
|
||||
- Pipeline bubbles (idle time)
|
||||
- Requires micro-batching
|
||||
|
||||
### Data Parallelism (DP)
|
||||
|
||||
**Replicates model across GPU groups**:
|
||||
|
||||
```python
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
# Group 0 (GPUs 0-3): Full model replica
|
||||
# Group 1 (GPUs 4-7): Full model replica
|
||||
```
|
||||
|
||||
**DP degree**:
|
||||
- `DP = total_gpus / (TP × PP)`
|
||||
- Example: 8 GPUs, TP=2, PP=2 → DP=2
|
||||
|
||||
**Benefits**:
|
||||
- Increases throughput
|
||||
- Scales batch size
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
**Splits long sequences across GPUs** (extends TP):
|
||||
|
||||
```python
|
||||
# 8K sequence, TP=2, Sequence Parallel=True
|
||||
# GPU 0: Tokens 0-4095
|
||||
# GPU 1: Tokens 4096-8191
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Enables very long sequences (100K+ tokens)
|
||||
- Reduces activation memory
|
||||
|
||||
**Requirements**:
|
||||
- Must use with TP > 1
|
||||
- RoPE/ALiBi position encodings work best
|
||||
|
||||
## Accelerate Code Example
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
|
||||
# Configure Megatron
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2, # Tensor parallelism degree
|
||||
pp_degree=2, # Pipeline parallelism degree
|
||||
num_micro_batches=4, # Micro-batches for pipeline
|
||||
gradient_clipping=1.0, # Gradient clipping value
|
||||
sequence_parallelism=False, # Enable sequence parallelism
|
||||
recompute_activations=True, # Activation checkpointing
|
||||
use_distributed_optimizer=True, # Distributed optimizer
|
||||
custom_prepare_model_function=None, # Custom model prep
|
||||
)
|
||||
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Prepare model and optimizer
|
||||
model, optimizer, train_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader
|
||||
)
|
||||
|
||||
# Training loop (same as DDP!)
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
### Full Training Script
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
def main():
|
||||
# Megatron configuration
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2,
|
||||
pp_degree=2,
|
||||
num_micro_batches=4,
|
||||
gradient_clipping=1.0,
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
gradient_accumulation_steps=8,
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Model
|
||||
config = GPT2Config(
|
||||
n_layer=24,
|
||||
n_head=16,
|
||||
n_embd=1024,
|
||||
)
|
||||
model = GPT2LMHeadModel(config)
|
||||
|
||||
# Optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
|
||||
|
||||
# Prepare
|
||||
model, optimizer, train_loader = accelerator.prepare(
|
||||
model, optimizer, train_loader
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(num_epochs):
|
||||
for batch in train_loader:
|
||||
with accelerator.accumulate(model):
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Save checkpoint
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.save_state(f'checkpoint-epoch-{epoch}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
### Launch Command
|
||||
|
||||
```bash
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node (2 nodes, 8 GPUs each)
|
||||
# Node 0
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
|
||||
# Node 1
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 1 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
```
|
||||
|
||||
## Activation Checkpointing
|
||||
|
||||
**Reduces memory by recomputing activations**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
recompute_activations=True, # Enable checkpointing
|
||||
checkpoint_num_layers=1, # Checkpoint every N layers
|
||||
distribute_checkpointed_activations=True, # Distribute across TP
|
||||
partition_activations=True, # Partition in PP
|
||||
check_for_nan_in_loss_and_grad=True, # Stability check
|
||||
)
|
||||
```
|
||||
|
||||
**Strategies**:
|
||||
- `SELECTIVE`: Checkpoint transformer blocks only
|
||||
- `FULL`: Checkpoint all layers
|
||||
- `NONE`: No checkpointing
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
## Distributed Optimizer
|
||||
|
||||
**Shards optimizer state across DP ranks**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
use_distributed_optimizer=True, # Enable sharded optimizer
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Reduces optimizer memory by DP degree
|
||||
- Example: DP=4 → 4× less optimizer memory per GPU
|
||||
|
||||
**Compatible with**:
|
||||
- AdamW, Adam, SGD
|
||||
- Mixed precision training
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Micro-Batch Size
|
||||
|
||||
```python
|
||||
# Pipeline parallelism requires micro-batching
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # 16 micro-batches per pipeline
|
||||
)
|
||||
|
||||
# Effective batch = num_micro_batches × micro_batch_size × DP
|
||||
# Example: 16 × 2 × 4 = 128
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- More micro-batches → less pipeline bubble
|
||||
- Typical: 4-16 micro-batches
|
||||
|
||||
### Sequence Length
|
||||
|
||||
```python
|
||||
# For long sequences, enable sequence parallelism
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=4,
|
||||
sequence_parallelism=True, # Required: TP > 1
|
||||
)
|
||||
|
||||
# Enables sequences up to TP × normal limit
|
||||
# Example: TP=4, 8K normal → 32K with sequence parallel
|
||||
```
|
||||
|
||||
### GPU Topology
|
||||
|
||||
**NVLink required for TP**:
|
||||
```bash
|
||||
# Check NVLink topology
|
||||
nvidia-smi topo -m
|
||||
|
||||
# Good topology (NVLink between all GPUs)
|
||||
# GPU0 - GPU1: NV12 (fast)
|
||||
# GPU0 - GPU2: NV12 (fast)
|
||||
|
||||
# Bad topology (PCIe only)
|
||||
# GPU0 - GPU4: PHB (slow, avoid TP across these)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- **TP**: Within same node (NVLink)
|
||||
- **PP**: Across nodes (slower interconnect OK)
|
||||
- **DP**: Any topology
|
||||
|
||||
## Model Size Guidelines
|
||||
|
||||
| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|
||||
|------------|------|----|----|----|--------------|
|
||||
| 7B | 8 | 1 | 1 | 8 | 1 |
|
||||
| 13B | 8 | 2 | 1 | 4 | 1 |
|
||||
| 20B | 16 | 4 | 1 | 4 | 1 |
|
||||
| 40B | 32 | 4 | 2 | 4 | 4 |
|
||||
| 70B | 64 | 8 | 2 | 4 | 8 |
|
||||
| 175B | 128 | 8 | 4 | 4 | 16 |
|
||||
|
||||
**Assumptions**: BF16, 2K sequence length, A100 80GB
|
||||
|
||||
## Checkpointing
|
||||
|
||||
### Save Checkpoint
|
||||
|
||||
```python
|
||||
# Save full model state
|
||||
accelerator.save_state('checkpoint-1000')
|
||||
|
||||
# Megatron saves separate files per rank
|
||||
# checkpoint-1000/
|
||||
# pytorch_model_tp_0_pp_0.bin
|
||||
# pytorch_model_tp_0_pp_1.bin
|
||||
# pytorch_model_tp_1_pp_0.bin
|
||||
# pytorch_model_tp_1_pp_1.bin
|
||||
# optimizer_tp_0_pp_0.bin
|
||||
# ...
|
||||
```
|
||||
|
||||
### Load Checkpoint
|
||||
|
||||
```python
|
||||
# Resume training
|
||||
accelerator.load_state('checkpoint-1000')
|
||||
|
||||
# Automatically loads correct shard per rank
|
||||
```
|
||||
|
||||
### Convert to Standard PyTorch
|
||||
|
||||
```bash
|
||||
# Merge Megatron checkpoint to single file
|
||||
python merge_megatron_checkpoint.py \
|
||||
--checkpoint-dir checkpoint-1000 \
|
||||
--output pytorch_model.bin
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: OOM with Pipeline Parallelism
|
||||
|
||||
**Solution**: Increase micro-batches
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # Increase from 4
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slow Training
|
||||
|
||||
**Check 1**: Pipeline bubbles (PP too high)
|
||||
```python
|
||||
# Reduce PP, increase TP
|
||||
tp_degree=4 # Increase
|
||||
pp_degree=2 # Decrease
|
||||
```
|
||||
|
||||
**Check 2**: Micro-batch size too small
|
||||
```python
|
||||
num_micro_batches=8 # Increase
|
||||
```
|
||||
|
||||
### Issue: NVLink Not Detected
|
||||
|
||||
```bash
|
||||
# Verify NVLink
|
||||
nvidia-smi nvlink -s
|
||||
|
||||
# If no NVLink, avoid TP > 1
|
||||
# Use PP or DP instead
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
|
||||
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
|
||||
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
|
||||
- NVIDIA Apex: https://github.com/NVIDIA/apex
|
||||
525
skills/mlops/accelerate/references/performance.md
Normal file
525
skills/mlops/accelerate/references/performance.md
Normal file
|
|
@ -0,0 +1,525 @@
|
|||
# Accelerate Performance Tuning
|
||||
|
||||
## Profiling
|
||||
|
||||
### Basic Profiling
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
import time
|
||||
|
||||
accelerator = Accelerator()
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
batch = next(iter(dataloader))
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Profile training loop
|
||||
start = time.time()
|
||||
total_batches = 100
|
||||
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= total_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone() # Sync all processes
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
batches_per_sec = total_batches / elapsed
|
||||
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
|
||||
print(f"Throughput: {samples_per_sec:.2f} samples/sec")
|
||||
print(f"Batches/sec: {batches_per_sec:.2f}")
|
||||
```
|
||||
|
||||
### PyTorch Profiler Integration
|
||||
|
||||
```python
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True
|
||||
) as prof:
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10: # Profile first 10 batches
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Print profiling results
|
||||
print(prof.key_averages().table(
|
||||
sort_by="cuda_time_total", row_limit=20
|
||||
))
|
||||
|
||||
# Export to Chrome tracing
|
||||
prof.export_chrome_trace("trace.json")
|
||||
# View at chrome://tracing
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### 1. Gradient Accumulation
|
||||
|
||||
**Problem**: Large batch size causes OOM
|
||||
|
||||
**Solution**: Accumulate gradients across micro-batches
|
||||
|
||||
```python
|
||||
accelerator = Accelerator(gradient_accumulation_steps=8)
|
||||
|
||||
# Effective batch = batch_size × accumulation_steps × num_gpus
|
||||
# Example: 4 × 8 × 8 = 256
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation logic
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
**Memory savings**: 8× less activation memory (with 8 accumulation steps)
|
||||
|
||||
### 2. Gradient Checkpointing
|
||||
|
||||
**Enable in model**:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
use_cache=False # Required for gradient checkpointing
|
||||
)
|
||||
|
||||
# Enable checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare with Accelerate
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
### 3. Mixed Precision
|
||||
|
||||
**BF16 (A100/H100)**:
|
||||
```python
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# Automatic mixed precision
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch) # Forward in BF16
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss) # Backward in FP32
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**FP16 (V100, older GPUs)**:
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16,
|
||||
growth_interval=2000
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% compared to FP32
|
||||
|
||||
### 4. CPU Offloading (DeepSpeed)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3,
|
||||
offload_optimizer_device="cpu", # Offload optimizer to CPU
|
||||
offload_param_device="cpu", # Offload parameters to CPU
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 10-20× for optimizer state, 5-10× for parameters
|
||||
|
||||
**Trade-off**: 20-30% slower due to CPU-GPU transfers
|
||||
|
||||
### 5. Flash Attention
|
||||
|
||||
```python
|
||||
# Install flash-attn
|
||||
# pip install flash-attn
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
attn_implementation="flash_attention_2" # Enable Flash Attention 2
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% for attention, 2× faster
|
||||
|
||||
**Requirements**: A100/H100, sequence length must be multiple of 128
|
||||
|
||||
## Communication Optimization
|
||||
|
||||
### 1. Gradient Bucketing (DDP)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Bucket size for gradient reduction
|
||||
gradient_as_bucket_view=True, # Reduce memory copies
|
||||
static_graph=False # Set True if model doesn't change
|
||||
)
|
||||
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
```
|
||||
|
||||
**Recommended bucket sizes**:
|
||||
- Small models (<1B): 25 MB
|
||||
- Medium models (1-10B): 50-100 MB
|
||||
- Large models (>10B): 100-200 MB
|
||||
|
||||
### 2. Find Unused Parameters
|
||||
|
||||
```python
|
||||
# Only enable if model has unused parameters (slower!)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
find_unused_parameters=True
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Models with conditional branches (e.g., mixture of experts)
|
||||
|
||||
**Cost**: 10-20% slower
|
||||
|
||||
### 3. NCCL Tuning
|
||||
|
||||
```bash
|
||||
# Set environment variables before launch
|
||||
export NCCL_DEBUG=INFO # Debug info
|
||||
export NCCL_IB_DISABLE=0 # Enable InfiniBand
|
||||
export NCCL_SOCKET_IFNAME=eth0 # Network interface
|
||||
export NCCL_P2P_LEVEL=NVL # Use NVLink
|
||||
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
**NCCL_P2P_LEVEL options**:
|
||||
- `NVL`: NVLink (fastest, within node)
|
||||
- `PIX`: PCIe (fast, within node)
|
||||
- `PHB`: PCIe host bridge (slow, cross-node)
|
||||
|
||||
## Data Loading Optimization
|
||||
|
||||
### 1. DataLoader Workers
|
||||
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # Parallel data loading
|
||||
pin_memory=True, # Pin memory for faster GPU transfer
|
||||
prefetch_factor=2, # Prefetch batches per worker
|
||||
persistent_workers=True # Keep workers alive between epochs
|
||||
)
|
||||
|
||||
train_loader = accelerator.prepare(train_loader)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
|
||||
- `pin_memory`: Always True for GPU training
|
||||
- `prefetch_factor`: 2-4 (higher for slow data loading)
|
||||
|
||||
### 2. Data Preprocessing
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Bad: Preprocess during training (slow)
|
||||
dataset = load_dataset("openwebtext")
|
||||
|
||||
for batch in dataset:
|
||||
tokens = tokenizer(batch['text']) # Slow!
|
||||
...
|
||||
|
||||
# Good: Preprocess once, save
|
||||
dataset = load_dataset("openwebtext")
|
||||
tokenized = dataset.map(
|
||||
lambda x: tokenizer(x['text']),
|
||||
batched=True,
|
||||
num_proc=8, # Parallel preprocessing
|
||||
remove_columns=['text']
|
||||
)
|
||||
tokenized.save_to_disk("preprocessed_data")
|
||||
|
||||
# Load preprocessed
|
||||
dataset = load_from_disk("preprocessed_data")
|
||||
```
|
||||
|
||||
### 3. Faster Tokenization
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Enable Rust-based tokenizers (10× faster)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"gpt2",
|
||||
use_fast=True # Use fast Rust tokenizer
|
||||
)
|
||||
```
|
||||
|
||||
## Compilation (PyTorch 2.0+)
|
||||
|
||||
### Compile Model
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Compile model for faster execution
|
||||
model = torch.compile(
|
||||
model,
|
||||
mode="reduce-overhead", # Options: default, reduce-overhead, max-autotune
|
||||
fullgraph=False, # Compile entire graph (stricter)
|
||||
dynamic=True # Support dynamic shapes
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Speedup**: 10-50% depending on model
|
||||
|
||||
**Compilation modes**:
|
||||
- `default`: Balanced (best for most cases)
|
||||
- `reduce-overhead`: Min overhead (best for small batches)
|
||||
- `max-autotune`: Max performance (slow compile, best for production)
|
||||
|
||||
### Compilation Best Practices
|
||||
|
||||
```python
|
||||
# Bad: Compile after prepare (won't work)
|
||||
model = accelerator.prepare(model)
|
||||
model = torch.compile(model) # Error!
|
||||
|
||||
# Good: Compile before prepare
|
||||
model = torch.compile(model)
|
||||
model = accelerator.prepare(model)
|
||||
|
||||
# Training loop
|
||||
for batch in dataloader:
|
||||
# First iteration: slow (compilation)
|
||||
# Subsequent iterations: fast (compiled)
|
||||
outputs = model(**batch)
|
||||
...
|
||||
```
|
||||
|
||||
## Benchmarking Different Strategies
|
||||
|
||||
### Script Template
|
||||
|
||||
```python
|
||||
import time
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
def benchmark_strategy(strategy_name, accelerator_kwargs):
|
||||
"""Benchmark a specific training strategy."""
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
||||
# Setup
|
||||
model = create_model()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
dataloader = create_dataloader()
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(
|
||||
model, optimizer, dataloader
|
||||
)
|
||||
|
||||
# Warmup
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10:
|
||||
break
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Benchmark
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
|
||||
num_batches = 100
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= num_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
memory_used = torch.cuda.max_memory_allocated() / 1e9 # GB
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print(f"\n{strategy_name}:")
|
||||
print(f" Throughput: {throughput:.2f} samples/sec")
|
||||
print(f" Memory: {memory_used:.2f} GB")
|
||||
print(f" Time: {elapsed:.2f} sec")
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Benchmark different strategies
|
||||
strategies = [
|
||||
("DDP + FP32", {}),
|
||||
("DDP + BF16", {"mixed_precision": "bf16"}),
|
||||
("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
|
||||
("FSDP", {"fsdp_plugin": fsdp_plugin}),
|
||||
("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
|
||||
("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
|
||||
]
|
||||
|
||||
for name, kwargs in strategies:
|
||||
benchmark_strategy(name, kwargs)
|
||||
```
|
||||
|
||||
## Performance Checklist
|
||||
|
||||
**Before training**:
|
||||
- [ ] Use BF16/FP16 mixed precision
|
||||
- [ ] Enable gradient checkpointing (if OOM)
|
||||
- [ ] Set appropriate `num_workers` (2-4 per GPU)
|
||||
- [ ] Enable `pin_memory=True`
|
||||
- [ ] Preprocess data once, not during training
|
||||
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)
|
||||
|
||||
**For large models**:
|
||||
- [ ] Use FSDP or DeepSpeed ZeRO-3
|
||||
- [ ] Enable CPU offloading (if still OOM)
|
||||
- [ ] Use Flash Attention
|
||||
- [ ] Increase gradient accumulation
|
||||
|
||||
**For multi-node**:
|
||||
- [ ] Check network topology (InfiniBand > Ethernet)
|
||||
- [ ] Tune NCCL settings
|
||||
- [ ] Use larger bucket sizes for DDP
|
||||
- [ ] Verify NVLink for tensor parallelism
|
||||
|
||||
**Profiling**:
|
||||
- [ ] Profile first 10-100 batches
|
||||
- [ ] Check GPU utilization (`nvidia-smi dmon`)
|
||||
- [ ] Check data loading time (should be <5% of iteration)
|
||||
- [ ] Identify communication bottlenecks
|
||||
|
||||
## Common Performance Issues
|
||||
|
||||
### Issue: Low GPU Utilization (<80%)
|
||||
|
||||
**Cause 1**: Data loading bottleneck
|
||||
```python
|
||||
# Solution: Increase workers and prefetch
|
||||
num_workers=8
|
||||
prefetch_factor=4
|
||||
```
|
||||
|
||||
**Cause 2**: Small batch size
|
||||
```python
|
||||
# Solution: Increase batch size or use gradient accumulation
|
||||
batch_size=32 # Increase
|
||||
gradient_accumulation_steps=4 # Or accumulate
|
||||
```
|
||||
|
||||
### Issue: High Memory Usage
|
||||
|
||||
**Solution 1**: Gradient checkpointing
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
|
||||
**Solution 2**: Reduce batch size, increase accumulation
|
||||
```python
|
||||
batch_size=8 # Reduce from 32
|
||||
gradient_accumulation_steps=16 # Maintain effective batch
|
||||
```
|
||||
|
||||
**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
|
||||
```python
|
||||
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
|
||||
```
|
||||
|
||||
### Issue: Slow Multi-GPU Training
|
||||
|
||||
**Cause**: Communication bottleneck
|
||||
|
||||
**Check 1**: Gradient bucket size
|
||||
```python
|
||||
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
|
||||
```
|
||||
|
||||
**Check 2**: NCCL settings
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
|
||||
```
|
||||
|
||||
**Check 3**: Network bandwidth
|
||||
```bash
|
||||
# Test inter-GPU bandwidth
|
||||
nvidia-smi nvlink -s
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
|
||||
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
|
||||
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
|
||||
- Flash Attention: https://github.com/Dao-AILab/flash-attention
|
||||
567
skills/mlops/audiocraft/SKILL.md
Normal file
567
skills/mlops/audiocraft/SKILL.md
Normal file
|
|
@ -0,0 +1,567 @@
|
|||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
|
||||
|
||||
---
|
||||
|
||||
# AudioCraft: Audio Generation
|
||||
|
||||
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
|
||||
|
||||
## When to use AudioCraft
|
||||
|
||||
**Use AudioCraft when:**
|
||||
- Need to generate music from text descriptions
|
||||
- Creating sound effects and environmental audio
|
||||
- Building music generation applications
|
||||
- Need melody-conditioned music generation
|
||||
- Want stereo audio output
|
||||
- Require controllable music generation with style transfer
|
||||
|
||||
**Key features:**
|
||||
- **MusicGen**: Text-to-music generation with melody conditioning
|
||||
- **AudioGen**: Text-to-sound effects generation
|
||||
- **EnCodec**: High-fidelity neural audio codec
|
||||
- **Multiple model sizes**: Small (300M) to Large (3.3B)
|
||||
- **Stereo support**: Full stereo audio generation
|
||||
- **Style conditioning**: MusicGen-Style for reference-based generation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Stable Audio**: For longer commercial music generation
|
||||
- **Bark**: For text-to-speech with music/sound effects
|
||||
- **Riffusion**: For spectogram-based music generation
|
||||
- **OpenAI Jukebox**: For raw audio generation with lyrics
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# From GitHub (latest)
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Or use HuggingFace Transformers
|
||||
pip install transformers torch torchaudio
|
||||
```
|
||||
|
||||
### Basic text-to-music (AudioCraft)
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Set generation parameters
|
||||
model.set_generation_params(
|
||||
duration=8, # seconds
|
||||
top_k=250,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Generate from text
|
||||
descriptions = ["happy upbeat electronic dance music with synths"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save audio
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Using HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
import scipy
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
model.to("cuda")
|
||||
|
||||
# Generate music
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
guidance_scale=3,
|
||||
max_new_tokens=256
|
||||
)
|
||||
|
||||
# Save
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
|
||||
```
|
||||
|
||||
### Text-to-sound with AudioGen
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
|
||||
# Load AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
model.set_generation_params(duration=5)
|
||||
|
||||
# Generate sound effects
|
||||
descriptions = ["dog barking in a park with birds chirping"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Architecture overview
|
||||
|
||||
```
|
||||
AudioCraft Architecture:
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Text Encoder (T5) │
|
||||
│ │ │
|
||||
│ Text Embeddings │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ Transformer Decoder (LM) │
|
||||
│ Auto-regressively generates audio tokens │
|
||||
│ Using efficient token interleaving patterns │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ EnCodec Audio Decoder │
|
||||
│ Converts tokens back to audio waveform │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Size | Description | Use Case |
|
||||
|-------|------|-------------|----------|
|
||||
| `musicgen-small` | 300M | Text-to-music | Quick generation |
|
||||
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
|
||||
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
|
||||
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
|
||||
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
|
||||
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
|
||||
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
|
||||
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `duration` | 8.0 | Length in seconds (1-120) |
|
||||
| `top_k` | 250 | Top-k sampling |
|
||||
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
|
||||
| `temperature` | 1.0 | Sampling temperature |
|
||||
| `cfg_coef` | 3.0 | Classifier-free guidance |
|
||||
|
||||
## MusicGen usage
|
||||
|
||||
### Text-to-music generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Configure generation
|
||||
model.set_generation_params(
|
||||
duration=30, # Up to 30 seconds
|
||||
top_k=250, # Sampling diversity
|
||||
top_p=0.0, # 0 = use top_k only
|
||||
temperature=1.0, # Creativity (higher = more varied)
|
||||
cfg_coef=3.0 # Text adherence (higher = stricter)
|
||||
)
|
||||
|
||||
# Generate multiple samples
|
||||
descriptions = [
|
||||
"epic orchestral soundtrack with strings and brass",
|
||||
"chill lo-fi hip hop beat with jazzy piano",
|
||||
"energetic rock song with electric guitar"
|
||||
]
|
||||
|
||||
# Generate (returns [batch, channels, samples])
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save each
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Melody-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
# Load melody model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
model.set_generation_params(duration=30)
|
||||
|
||||
# Load melody audio
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Generate with melody conditioning
|
||||
descriptions = ["acoustic guitar folk song"]
|
||||
wav = model.generate_with_chroma(descriptions, melody, sr)
|
||||
|
||||
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Stereo generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load stereo model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
model.set_generation_params(duration=15)
|
||||
|
||||
descriptions = ["ambient electronic music with wide stereo panning"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# wav shape: [batch, 2, samples] for stereo
|
||||
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
|
||||
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Audio continuation
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
# Load audio to continue
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load("intro.wav")
|
||||
|
||||
# Process with text and audio
|
||||
inputs = processor(
|
||||
audio=audio.squeeze().numpy(),
|
||||
sampling_rate=sr,
|
||||
text=["continue with a epic chorus"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Generate continuation
|
||||
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
|
||||
```
|
||||
|
||||
## MusicGen-Style usage
|
||||
|
||||
### Style-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load style model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-style')
|
||||
|
||||
# Configure generation with style
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=5.0 # Style influence
|
||||
)
|
||||
|
||||
# Configure style conditioner
|
||||
model.set_style_conditioner_params(
|
||||
eval_q=3, # RVQ quantizers (1-6)
|
||||
excerpt_length=3.0 # Style excerpt length
|
||||
)
|
||||
|
||||
# Load style reference
|
||||
style_audio, sr = torchaudio.load("reference_style.wav")
|
||||
|
||||
# Generate with text + style
|
||||
descriptions = ["upbeat dance track"]
|
||||
wav = model.generate_with_style(descriptions, style_audio, sr)
|
||||
```
|
||||
|
||||
### Style-only generation (no text)
|
||||
|
||||
```python
|
||||
# Generate matching style without text prompt
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=None # Disable double CFG for style-only
|
||||
)
|
||||
|
||||
wav = model.generate_with_style([None], style_audio, sr)
|
||||
```
|
||||
|
||||
## AudioGen usage
|
||||
|
||||
### Sound effect generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate various sounds
|
||||
descriptions = [
|
||||
"thunderstorm with heavy rain and lightning",
|
||||
"busy city traffic with car horns",
|
||||
"ocean waves crashing on rocks",
|
||||
"crackling campfire in forest"
|
||||
]
|
||||
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## EnCodec usage
|
||||
|
||||
### Audio compression
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# Load EnCodec
|
||||
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Load audio
|
||||
wav, sr = torchaudio.load("audio.wav")
|
||||
|
||||
# Ensure correct sample rate
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Encode to tokens
|
||||
with torch.no_grad():
|
||||
encoded = model.encode(wav.unsqueeze(0))
|
||||
codes = encoded[0] # Audio codes
|
||||
|
||||
# Decode back to audio
|
||||
with torch.no_grad():
|
||||
decoded = model.decode(codes)
|
||||
|
||||
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Music generation pipeline
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenerator:
|
||||
def __init__(self, model_name="facebook/musicgen-medium"):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.sample_rate = 32000
|
||||
|
||||
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
|
||||
self.model.set_generation_params(
|
||||
duration=duration,
|
||||
top_k=250,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([prompt])
|
||||
|
||||
return wav[0].cpu()
|
||||
|
||||
def generate_batch(self, prompts, duration=30):
|
||||
self.model.set_generation_params(duration=duration)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate(prompts)
|
||||
|
||||
return wav.cpu()
|
||||
|
||||
def save(self, audio, path):
|
||||
torchaudio.save(path, audio, sample_rate=self.sample_rate)
|
||||
|
||||
# Usage
|
||||
generator = MusicGenerator()
|
||||
audio = generator.generate(
|
||||
"epic cinematic orchestral music",
|
||||
duration=30,
|
||||
temperature=1.0
|
||||
)
|
||||
generator.save(audio, "epic_music.wav")
|
||||
```
|
||||
|
||||
### Workflow 2: Sound design batch processing
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
def batch_generate_sounds(sound_specs, output_dir):
|
||||
"""
|
||||
Generate multiple sounds from specifications.
|
||||
|
||||
Args:
|
||||
sound_specs: list of {"name": str, "description": str, "duration": float}
|
||||
output_dir: output directory path
|
||||
"""
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
results = []
|
||||
|
||||
for spec in sound_specs:
|
||||
model.set_generation_params(duration=spec.get("duration", 5))
|
||||
|
||||
wav = model.generate([spec["description"]])
|
||||
|
||||
output_path = output_dir / f"{spec['name']}.wav"
|
||||
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
|
||||
|
||||
results.append({
|
||||
"name": spec["name"],
|
||||
"path": str(output_path),
|
||||
"description": spec["description"]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
sounds = [
|
||||
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
|
||||
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
|
||||
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
|
||||
]
|
||||
|
||||
results = batch_generate_sounds(sounds, "sound_effects/")
|
||||
```
|
||||
|
||||
### Workflow 3: Gradio demo
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
def generate_music(prompt, duration, temperature, cfg_coef):
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save to temp file
|
||||
path = "temp_output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate_music,
|
||||
inputs=[
|
||||
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
|
||||
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
|
||||
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Demo"
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Memory optimization
|
||||
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Clear cache between generations
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Generate shorter durations
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Use half precision
|
||||
model = model.half()
|
||||
```
|
||||
|
||||
### Batch processing efficiency
|
||||
|
||||
```python
|
||||
# Process multiple prompts at once (more efficient)
|
||||
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
|
||||
wav = model.generate(descriptions) # Single batch
|
||||
|
||||
# Instead of
|
||||
for desc in descriptions:
|
||||
wav = model.generate([desc]) # Multiple batches (slower)
|
||||
```
|
||||
|
||||
### GPU memory requirements
|
||||
|
||||
| Model | FP32 VRAM | FP16 VRAM |
|
||||
|-------|-----------|-----------|
|
||||
| musicgen-small | ~4GB | ~2GB |
|
||||
| musicgen-medium | ~8GB | ~4GB |
|
||||
| musicgen-large | ~16GB | ~8GB |
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| CUDA OOM | Use smaller model, reduce duration |
|
||||
| Poor quality | Increase cfg_coef, better prompts |
|
||||
| Generation too short | Check max duration setting |
|
||||
| Audio artifacts | Try different temperature |
|
||||
| Stereo not working | Use stereo model variant |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/audiocraft
|
||||
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
|
||||
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
|
||||
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
|
||||
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
|
||||
666
skills/mlops/audiocraft/references/advanced-usage.md
Normal file
666
skills/mlops/audiocraft/references/advanced-usage.md
Normal file
|
|
@ -0,0 +1,666 @@
|
|||
# AudioCraft Advanced Usage Guide
|
||||
|
||||
## Fine-tuning MusicGen
|
||||
|
||||
### Custom dataset preparation
|
||||
|
||||
```python
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torchaudio
|
||||
|
||||
def prepare_dataset(audio_dir, output_dir, metadata_file):
|
||||
"""
|
||||
Prepare dataset for MusicGen fine-tuning.
|
||||
|
||||
Directory structure:
|
||||
output_dir/
|
||||
├── audio/
|
||||
│ ├── 0001.wav
|
||||
│ ├── 0002.wav
|
||||
│ └── ...
|
||||
└── metadata.json
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
audio_output = output_dir / "audio"
|
||||
audio_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load metadata (format: {"path": "...", "description": "..."})
|
||||
with open(metadata_file) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
processed = []
|
||||
|
||||
for idx, item in enumerate(metadata):
|
||||
audio_path = Path(audio_dir) / item["path"]
|
||||
|
||||
# Load and resample to 32kHz
|
||||
wav, sr = torchaudio.load(str(audio_path))
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Convert to mono if stereo
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(dim=0, keepdim=True)
|
||||
|
||||
# Save processed audio
|
||||
output_path = audio_output / f"{idx:04d}.wav"
|
||||
torchaudio.save(str(output_path), wav, sample_rate=32000)
|
||||
|
||||
processed.append({
|
||||
"path": str(output_path.relative_to(output_dir)),
|
||||
"description": item["description"],
|
||||
"duration": wav.shape[1] / 32000
|
||||
})
|
||||
|
||||
# Save processed metadata
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(processed, f, indent=2)
|
||||
|
||||
print(f"Processed {len(processed)} samples")
|
||||
return processed
|
||||
```
|
||||
|
||||
### Fine-tuning with dora
|
||||
|
||||
```bash
|
||||
# AudioCraft uses dora for experiment management
|
||||
# Install dora
|
||||
pip install dora-search
|
||||
|
||||
# Clone AudioCraft
|
||||
git clone https://github.com/facebookresearch/audiocraft.git
|
||||
cd audiocraft
|
||||
|
||||
# Create config for fine-tuning
|
||||
cat > config/solver/musicgen/finetune.yaml << 'EOF'
|
||||
defaults:
|
||||
- musicgen/musicgen_base
|
||||
- /model: lm/musicgen_lm
|
||||
- /conditioner: cond_base
|
||||
|
||||
solver: musicgen
|
||||
autocast: true
|
||||
autocast_dtype: float16
|
||||
|
||||
optim:
|
||||
epochs: 100
|
||||
batch_size: 4
|
||||
lr: 1e-4
|
||||
ema: 0.999
|
||||
optimizer: adamw
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
train:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
valid:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
|
||||
checkpoint:
|
||||
save_every: 10
|
||||
keep_every_states: null
|
||||
EOF
|
||||
|
||||
# Run fine-tuning
|
||||
dora run solver=musicgen/finetune
|
||||
```
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from audiocraft.models import MusicGen
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Get the language model component
|
||||
lm = model.lm
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
lm = get_peft_model(lm, lora_config)
|
||||
lm.print_trainable_parameters()
|
||||
```
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
### DataParallel
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Wrap LM with DataParallel
|
||||
if torch.cuda.device_count() > 1:
|
||||
model.lm = nn.DataParallel(model.lm)
|
||||
|
||||
model.to("cuda")
|
||||
```
|
||||
|
||||
### DistributedDataParallel
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup(rank, world_size):
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
def train(rank, world_size):
|
||||
setup(rank, world_size)
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.lm = model.lm.to(rank)
|
||||
model.lm = DDP(model.lm, device_ids=[rank])
|
||||
|
||||
# Training loop
|
||||
# ...
|
||||
|
||||
dist.destroy_process_group()
|
||||
```
|
||||
|
||||
## Custom Conditioning
|
||||
|
||||
### Adding new conditioners
|
||||
|
||||
```python
|
||||
from audiocraft.modules.conditioners import BaseConditioner
|
||||
import torch
|
||||
|
||||
class CustomConditioner(BaseConditioner):
|
||||
"""Custom conditioner for additional control signals."""
|
||||
|
||||
def __init__(self, dim, output_dim):
|
||||
super().__init__(dim, output_dim)
|
||||
self.embed = torch.nn.Linear(dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
def tokenize(self, x):
|
||||
# Tokenize input for conditioning
|
||||
return x
|
||||
|
||||
# Use with MusicGen
|
||||
from audiocraft.models.builders import get_lm_model
|
||||
|
||||
# Modify model config to include custom conditioner
|
||||
# This requires editing the model configuration
|
||||
```
|
||||
|
||||
### Melody conditioning internals
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
|
||||
import torch
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Access chroma extractor
|
||||
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
|
||||
|
||||
# Manual chroma extraction
|
||||
def extract_chroma(audio, sr):
|
||||
"""Extract chroma features from audio."""
|
||||
import librosa
|
||||
|
||||
# Compute chroma
|
||||
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
|
||||
|
||||
return torch.from_numpy(chroma).float()
|
||||
|
||||
# Use extracted chroma for conditioning
|
||||
chroma = extract_chroma(melody_audio, sample_rate)
|
||||
```
|
||||
|
||||
## EnCodec Deep Dive
|
||||
|
||||
### Custom compression settings
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
|
||||
# Load EnCodec
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Access codec parameters
|
||||
print(f"Sample rate: {encodec.sample_rate}")
|
||||
print(f"Channels: {encodec.channels}")
|
||||
print(f"Cardinality: {encodec.cardinality}") # Codebook size
|
||||
print(f"Num codebooks: {encodec.num_codebooks}")
|
||||
print(f"Frame rate: {encodec.frame_rate}")
|
||||
|
||||
# Encode with specific bandwidth
|
||||
# Lower bandwidth = more compression, lower quality
|
||||
encodec.set_target_bandwidth(6.0) # 6 kbps
|
||||
|
||||
audio = torch.randn(1, 1, 32000) # 1 second
|
||||
encoded = encodec.encode(audio)
|
||||
decoded = encodec.decode(encoded[0])
|
||||
```
|
||||
|
||||
### Streaming encoding
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.models import CompressionModel
|
||||
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
def encode_streaming(audio_stream, chunk_size=32000):
|
||||
"""Encode audio in streaming fashion."""
|
||||
all_codes = []
|
||||
|
||||
for chunk in audio_stream:
|
||||
# Ensure chunk is right shape
|
||||
if chunk.dim() == 1:
|
||||
chunk = chunk.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
codes = encodec.encode(chunk)[0]
|
||||
all_codes.append(codes)
|
||||
|
||||
return torch.cat(all_codes, dim=-1)
|
||||
|
||||
def decode_streaming(codes_stream, output_stream):
|
||||
"""Decode codes in streaming fashion."""
|
||||
for codes in codes_stream:
|
||||
with torch.no_grad():
|
||||
audio = encodec.decode(codes)
|
||||
output_stream.write(audio.cpu().numpy())
|
||||
```
|
||||
|
||||
## MultiBand Diffusion
|
||||
|
||||
### Using MBD for enhanced quality
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen, MultiBandDiffusion
|
||||
|
||||
# Load MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Load MultiBand Diffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate with standard decoder
|
||||
descriptions = ["epic orchestral music"]
|
||||
wav_standard = model.generate(descriptions)
|
||||
|
||||
# Generate tokens and use MBD decoder
|
||||
with torch.no_grad():
|
||||
# Get tokens
|
||||
gen_tokens = model.generate_tokens(descriptions)
|
||||
|
||||
# Decode with MBD
|
||||
wav_mbd = mbd.tokens_to_wav(gen_tokens)
|
||||
|
||||
# Compare quality
|
||||
print(f"Standard shape: {wav_standard.shape}")
|
||||
print(f"MBD shape: {wav_mbd.shape}")
|
||||
```
|
||||
|
||||
## API Server Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import io
|
||||
import base64
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
model = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def load_model():
|
||||
global model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
duration: float = 10.0
|
||||
temperature: float = 1.0
|
||||
cfg_coef: float = 3.0
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
audio_base64: str
|
||||
sample_rate: int
|
||||
duration: float
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
model.set_generation_params(
|
||||
duration=min(request.duration, 30),
|
||||
temperature=request.temperature,
|
||||
cfg_coef=request.cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([request.prompt])
|
||||
|
||||
# Convert to bytes
|
||||
buffer = io.BytesIO()
|
||||
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
|
||||
buffer.seek(0)
|
||||
|
||||
audio_base64 = base64.b64encode(buffer.read()).decode()
|
||||
|
||||
return GenerateResponse(
|
||||
audio_base64=audio_base64,
|
||||
sample_rate=32000,
|
||||
duration=wav.shape[-1] / 32000
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "model_loaded": model is not None}
|
||||
|
||||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Batch processing service
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import torch
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenService:
|
||||
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def generate_async(self, prompt, duration=10):
|
||||
"""Async generation with thread pool."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _generate():
|
||||
with torch.no_grad():
|
||||
self.model.set_generation_params(duration=duration)
|
||||
return self.model.generate([prompt])
|
||||
|
||||
# Run in thread pool
|
||||
wav = await loop.run_in_executor(self.executor, _generate)
|
||||
return wav[0].cpu()
|
||||
|
||||
async def generate_batch_async(self, prompts, duration=10):
|
||||
"""Process multiple prompts concurrently."""
|
||||
tasks = [self.generate_async(p, duration) for p in prompts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
# Usage
|
||||
service = MusicGenService()
|
||||
|
||||
async def main():
|
||||
prompts = ["jazz piano", "rock guitar", "electronic beats"]
|
||||
results = await service.generate_batch_async(prompts)
|
||||
return results
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### LangChain tool
|
||||
|
||||
```python
|
||||
from langchain.tools import BaseTool
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import tempfile
|
||||
|
||||
class MusicGeneratorTool(BaseTool):
|
||||
name = "music_generator"
|
||||
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
self.model.set_generation_params(duration=15)
|
||||
|
||||
def _run(self, description: str) -> str:
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([description])
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
|
||||
return f"Generated music saved to: {f.name}"
|
||||
|
||||
async def _arun(self, description: str) -> str:
|
||||
return self._run(description)
|
||||
```
|
||||
|
||||
### Gradio with advanced controls
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
models = {}
|
||||
|
||||
def load_model(model_size):
|
||||
if model_size not in models:
|
||||
model_name = f"facebook/musicgen-{model_size}"
|
||||
models[model_size] = MusicGen.get_pretrained(model_name)
|
||||
return models[model_size]
|
||||
|
||||
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
|
||||
model = load_model(model_size)
|
||||
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save
|
||||
path = "output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
gr.Textbox(label="Prompt", lines=3),
|
||||
gr.Slider(1, 30, value=10, label="Duration (s)"),
|
||||
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
|
||||
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
|
||||
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Advanced",
|
||||
allow_flagging="never"
|
||||
)
|
||||
|
||||
demo.launch(share=True)
|
||||
```
|
||||
|
||||
## Audio Processing Pipeline
|
||||
|
||||
### Post-processing chain
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.transforms as T
|
||||
import numpy as np
|
||||
|
||||
class AudioPostProcessor:
|
||||
def __init__(self, sample_rate=32000):
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
def normalize(self, audio, target_db=-14.0):
|
||||
"""Normalize audio to target loudness."""
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
def fade_in_out(self, audio, fade_duration=0.1):
|
||||
"""Apply fade in/out."""
|
||||
fade_samples = int(fade_duration * self.sample_rate)
|
||||
|
||||
# Create fade curves
|
||||
fade_in = torch.linspace(0, 1, fade_samples)
|
||||
fade_out = torch.linspace(1, 0, fade_samples)
|
||||
|
||||
# Apply fades
|
||||
audio[..., :fade_samples] *= fade_in
|
||||
audio[..., -fade_samples:] *= fade_out
|
||||
|
||||
return audio
|
||||
|
||||
def apply_reverb(self, audio, decay=0.5):
|
||||
"""Apply simple reverb effect."""
|
||||
impulse = torch.zeros(int(self.sample_rate * 0.5))
|
||||
impulse[0] = 1.0
|
||||
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
|
||||
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
|
||||
|
||||
# Convolve
|
||||
audio = torch.nn.functional.conv1d(
|
||||
audio.unsqueeze(0),
|
||||
impulse.unsqueeze(0).unsqueeze(0),
|
||||
padding=len(impulse) // 2
|
||||
).squeeze(0)
|
||||
|
||||
return audio
|
||||
|
||||
def process(self, audio):
|
||||
"""Full processing pipeline."""
|
||||
audio = self.normalize(audio)
|
||||
audio = self.fade_in_out(audio)
|
||||
return audio
|
||||
|
||||
# Usage with MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
wav = model.generate(["chill ambient music"])
|
||||
processor = AudioPostProcessor()
|
||||
wav_processed = processor.process(wav[0].cpu())
|
||||
|
||||
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Audio quality metrics
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.metrics import CLAPTextConsistencyMetric
|
||||
from audiocraft.data.audio import audio_read
|
||||
|
||||
def evaluate_generation(audio_path, text_prompt):
|
||||
"""Evaluate generated audio quality."""
|
||||
# Load audio
|
||||
wav, sr = audio_read(audio_path)
|
||||
|
||||
# CLAP consistency (text-audio alignment)
|
||||
clap_metric = CLAPTextConsistencyMetric()
|
||||
clap_score = clap_metric.compute(wav, [text_prompt])
|
||||
|
||||
return {
|
||||
"clap_score": clap_score,
|
||||
"duration": wav.shape[-1] / sr
|
||||
}
|
||||
|
||||
# Batch evaluation
|
||||
def evaluate_batch(generations):
|
||||
"""Evaluate multiple generations."""
|
||||
results = []
|
||||
for gen in generations:
|
||||
result = evaluate_generation(gen["path"], gen["prompt"])
|
||||
result["prompt"] = gen["prompt"]
|
||||
results.append(result)
|
||||
|
||||
# Aggregate
|
||||
avg_clap = sum(r["clap_score"] for r in results) / len(results)
|
||||
return {
|
||||
"individual": results,
|
||||
"average_clap": avg_clap
|
||||
}
|
||||
```
|
||||
|
||||
## Model Comparison
|
||||
|
||||
### MusicGen variants benchmark
|
||||
|
||||
| Model | CLAP Score | Generation Time (10s) | VRAM |
|
||||
|-------|------------|----------------------|------|
|
||||
| musicgen-small | 0.35 | ~5s | 2GB |
|
||||
| musicgen-medium | 0.42 | ~15s | 4GB |
|
||||
| musicgen-large | 0.48 | ~30s | 8GB |
|
||||
| musicgen-melody | 0.45 | ~15s | 4GB |
|
||||
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
|
||||
|
||||
### Prompt engineering tips
|
||||
|
||||
```python
|
||||
# Good prompts - specific and descriptive
|
||||
good_prompts = [
|
||||
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
|
||||
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
|
||||
"funky disco groove with slap bass, brass section, and rhythmic guitar"
|
||||
]
|
||||
|
||||
# Bad prompts - too vague
|
||||
bad_prompts = [
|
||||
"nice music",
|
||||
"song",
|
||||
"good beat"
|
||||
]
|
||||
|
||||
# Structure: [mood] [genre] with [instruments] at [tempo/style]
|
||||
```
|
||||
504
skills/mlops/audiocraft/references/troubleshooting.md
Normal file
504
skills/mlops/audiocraft/references/troubleshooting.md
Normal file
|
|
@ -0,0 +1,504 @@
|
|||
# AudioCraft Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# Or from GitHub
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Verify installation
|
||||
python -c "from audiocraft.models import MusicGen; print('OK')"
|
||||
```
|
||||
|
||||
### FFmpeg not found
|
||||
|
||||
**Error**: `RuntimeError: ffmpeg not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows (using conda)
|
||||
conda install -c conda-forge ffmpeg
|
||||
|
||||
# Verify
|
||||
ffmpeg -version
|
||||
```
|
||||
|
||||
### PyTorch CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
python -c "import torch; print(torch.version.cuda)"
|
||||
|
||||
# Install matching PyTorch
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For CUDA 11.8
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
### xformers issues
|
||||
|
||||
**Error**: `ImportError: xformers` related errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install xformers for memory efficiency
|
||||
pip install xformers
|
||||
|
||||
# Or disable xformers
|
||||
export AUDIOCRAFT_USE_XFORMERS=0
|
||||
|
||||
# In Python
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
|
||||
from audiocraft.models import MusicGen
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Force CPU loading first
|
||||
import torch
|
||||
device = "cpu"
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
|
||||
model = model.to("cuda")
|
||||
|
||||
# Use HuggingFace with device_map
|
||||
from transformers import MusicgenForConditionalGeneration
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||
"facebook/musicgen-small",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Download failures
|
||||
|
||||
**Error**: Connection errors or incomplete downloads
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Set cache directory
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
|
||||
|
||||
# Or for HuggingFace
|
||||
os.environ["HF_HOME"] = "/path/to/hf_cache"
|
||||
|
||||
# Resume download
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("facebook/musicgen-small", resume_download=True)
|
||||
|
||||
# Use local files
|
||||
model = MusicGen.get_pretrained('/local/path/to/model')
|
||||
```
|
||||
|
||||
### Wrong model type
|
||||
|
||||
**Error**: Loading wrong model for task
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# For text-to-music: use MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# For text-to-sound: use AudioGen
|
||||
from audiocraft.models import AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
# For melody conditioning: use melody variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# For stereo: use stereo variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Empty or silent output
|
||||
|
||||
**Problem**: Generated audio is silent or very quiet
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check output
|
||||
wav = model.generate(["upbeat music"])
|
||||
print(f"Shape: {wav.shape}")
|
||||
print(f"Max amplitude: {wav.abs().max().item()}")
|
||||
print(f"Mean amplitude: {wav.abs().mean().item()}")
|
||||
|
||||
# If too quiet, normalize
|
||||
def normalize_audio(audio, target_db=-14.0):
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
wav_normalized = normalize_audio(wav)
|
||||
```
|
||||
|
||||
### Poor quality output
|
||||
|
||||
**Problem**: Generated music sounds bad or noisy
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use larger model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-large')
|
||||
|
||||
# Adjust generation parameters
|
||||
model.set_generation_params(
|
||||
duration=15,
|
||||
top_k=250, # Increase for more diversity
|
||||
temperature=0.8, # Lower for more focused output
|
||||
cfg_coef=4.0 # Increase for better text adherence
|
||||
)
|
||||
|
||||
# Use better prompts
|
||||
# Bad: "music"
|
||||
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
|
||||
|
||||
# Try MultiBand Diffusion
|
||||
from audiocraft.models import MultiBandDiffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
tokens = model.generate_tokens(["prompt"])
|
||||
wav = mbd.tokens_to_wav(tokens)
|
||||
```
|
||||
|
||||
### Generation too short
|
||||
|
||||
**Problem**: Audio shorter than expected
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check duration setting
|
||||
model.set_generation_params(duration=30) # Set before generate
|
||||
|
||||
# Verify in generation
|
||||
print(f"Duration setting: {model.generation_params}")
|
||||
|
||||
# Check output shape
|
||||
wav = model.generate(["prompt"])
|
||||
actual_duration = wav.shape[-1] / 32000
|
||||
print(f"Actual duration: {actual_duration}s")
|
||||
|
||||
# Note: max duration is typically 30s
|
||||
```
|
||||
|
||||
### Melody conditioning fails
|
||||
|
||||
**Error**: Issues with melody-conditioned generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load melody model (not base model)
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Load and prepare melody
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Resample to model sample rate if needed
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
melody = resampler(melody)
|
||||
|
||||
# Ensure correct shape [batch, channels, samples]
|
||||
if melody.dim() == 1:
|
||||
melody = melody.unsqueeze(0).unsqueeze(0)
|
||||
elif melody.dim() == 2:
|
||||
melody = melody.unsqueeze(0)
|
||||
|
||||
# Convert stereo to mono
|
||||
if melody.shape[1] > 1:
|
||||
melody = melody.mean(dim=1, keepdim=True)
|
||||
|
||||
# Generate with melody
|
||||
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
|
||||
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Clear cache before generation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Generate one at a time
|
||||
for prompt in prompts:
|
||||
wav = model.generate([prompt])
|
||||
save_audio(wav)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use CPU for very large generations
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
|
||||
```
|
||||
|
||||
### Memory leak during batch processing
|
||||
|
||||
**Problem**: Memory grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(model, prompts):
|
||||
results = []
|
||||
|
||||
for prompt in prompts:
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
results.append(wav.cpu())
|
||||
|
||||
# Cleanup
|
||||
del wav
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
# Use context manager
|
||||
with torch.inference_mode():
|
||||
wav = model.generate(["prompt"])
|
||||
```
|
||||
|
||||
## Audio Format Issues
|
||||
|
||||
### Wrong sample rate
|
||||
|
||||
**Problem**: Audio plays at wrong speed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
|
||||
# MusicGen outputs at 32kHz
|
||||
sample_rate = 32000
|
||||
|
||||
# AudioGen outputs at 16kHz
|
||||
sample_rate = 16000
|
||||
|
||||
# Always use correct rate when saving
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
|
||||
|
||||
# Resample if needed
|
||||
resampler = torchaudio.transforms.Resample(32000, 44100)
|
||||
wav_resampled = resampler(wav)
|
||||
```
|
||||
|
||||
### Stereo/mono mismatch
|
||||
|
||||
**Problem**: Wrong number of channels
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check model type
|
||||
print(f"Audio channels: {wav.shape}")
|
||||
# Mono: [batch, 1, samples]
|
||||
# Stereo: [batch, 2, samples]
|
||||
|
||||
# Convert mono to stereo
|
||||
if wav.shape[1] == 1:
|
||||
wav_stereo = wav.repeat(1, 2, 1)
|
||||
|
||||
# Convert stereo to mono
|
||||
if wav.shape[1] == 2:
|
||||
wav_mono = wav.mean(dim=1, keepdim=True)
|
||||
|
||||
# Use stereo model for stereo output
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
### Clipping and distortion
|
||||
|
||||
**Problem**: Audio has clipping or distortion
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check for clipping
|
||||
max_val = wav.abs().max().item()
|
||||
print(f"Max amplitude: {max_val}")
|
||||
|
||||
# Normalize to prevent clipping
|
||||
if max_val > 1.0:
|
||||
wav = wav / max_val
|
||||
|
||||
# Apply soft clipping
|
||||
def soft_clip(x, threshold=0.9):
|
||||
return torch.tanh(x / threshold) * threshold
|
||||
|
||||
wav_clipped = soft_clip(wav)
|
||||
|
||||
# Lower temperature during generation
|
||||
model.set_generation_params(temperature=0.7) # More controlled
|
||||
```
|
||||
|
||||
## HuggingFace Transformers Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with MusicgenProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
# Load matching processor and model
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
# Ensure inputs are on same device
|
||||
inputs = processor(
|
||||
text=["prompt"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
# Check processor configuration
|
||||
print(processor.tokenizer)
|
||||
print(processor.feature_extractor)
|
||||
```
|
||||
|
||||
### Generation parameter errors
|
||||
|
||||
**Error**: Invalid generation parameters
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# HuggingFace uses different parameter names
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True, # Enable sampling
|
||||
guidance_scale=3.0, # CFG (not cfg_coef)
|
||||
max_new_tokens=256, # Token limit (not duration)
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Calculate tokens from duration
|
||||
# ~50 tokens per second
|
||||
duration_seconds = 10
|
||||
max_tokens = duration_seconds * 50
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
|
||||
```
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Use GPU
|
||||
model.to("cuda")
|
||||
|
||||
# Enable flash attention if available
|
||||
# (requires compatible hardware)
|
||||
|
||||
# Batch multiple prompts
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
wav = model.generate(prompts) # Single batch is faster than loop
|
||||
|
||||
# Use compile (PyTorch 2.0+)
|
||||
model.lm = torch.compile(model.lm)
|
||||
```
|
||||
|
||||
### CPU fallback
|
||||
|
||||
**Problem**: Generation running on CPU instead of GPU
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check CUDA availability
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Explicitly move to GPU
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.to("cuda")
|
||||
|
||||
# Verify model device
|
||||
print(f"Model device: {next(model.lm.parameters()).device}")
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
|
||||
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
|
||||
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
|
||||
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
|
||||
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
|
||||
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2306.05284
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Python version
|
||||
- PyTorch version
|
||||
- CUDA version
|
||||
- AudioCraft version: `pip show audiocraft`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Hardware (GPU model, VRAM)
|
||||
81
skills/mlops/code-review/SKILL.md
Normal file
81
skills/mlops/code-review/SKILL.md
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
---
|
||||
name: code-review
|
||||
description: Guidelines for performing thorough code reviews with security and quality focus
|
||||
---
|
||||
|
||||
# Code Review Skill
|
||||
|
||||
Use this skill when reviewing code changes, pull requests, or auditing existing code.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
### 1. Security First
|
||||
- [ ] No hardcoded secrets, API keys, or credentials
|
||||
- [ ] Input validation on all user-provided data
|
||||
- [ ] SQL queries use parameterized statements (no string concatenation)
|
||||
- [ ] File operations validate paths (no path traversal)
|
||||
- [ ] Authentication/authorization checks present where needed
|
||||
|
||||
### 2. Error Handling
|
||||
- [ ] All external calls (API, DB, file) have try/catch
|
||||
- [ ] Errors are logged with context (but no sensitive data)
|
||||
- [ ] User-facing errors are helpful but don't leak internals
|
||||
- [ ] Resources are cleaned up in finally blocks or context managers
|
||||
|
||||
### 3. Code Quality
|
||||
- [ ] Functions do one thing and are reasonably sized (<50 lines ideal)
|
||||
- [ ] Variable names are descriptive (no single letters except loops)
|
||||
- [ ] No commented-out code left behind
|
||||
- [ ] Complex logic has explanatory comments
|
||||
- [ ] No duplicate code (DRY principle)
|
||||
|
||||
### 4. Testing Considerations
|
||||
- [ ] Edge cases handled (empty inputs, nulls, boundaries)
|
||||
- [ ] Happy path and error paths both work
|
||||
- [ ] New code has corresponding tests (if test suite exists)
|
||||
|
||||
## Review Response Format
|
||||
|
||||
When providing review feedback, structure it as:
|
||||
|
||||
```
|
||||
## Summary
|
||||
[1-2 sentence overall assessment]
|
||||
|
||||
## Critical Issues (Must Fix)
|
||||
- Issue 1: [description + suggested fix]
|
||||
- Issue 2: ...
|
||||
|
||||
## Suggestions (Nice to Have)
|
||||
- Suggestion 1: [description]
|
||||
|
||||
## Questions
|
||||
- [Any clarifying questions about intent]
|
||||
```
|
||||
|
||||
## Common Patterns to Flag
|
||||
|
||||
### Python
|
||||
```python
|
||||
# Bad: SQL injection risk
|
||||
cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
|
||||
|
||||
# Good: Parameterized query
|
||||
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
```
|
||||
|
||||
### JavaScript
|
||||
```javascript
|
||||
// Bad: XSS risk
|
||||
element.innerHTML = userInput;
|
||||
|
||||
// Good: Safe text content
|
||||
element.textContent = userInput;
|
||||
```
|
||||
|
||||
## Tone Guidelines
|
||||
|
||||
- Be constructive, not critical
|
||||
- Explain *why* something is an issue, not just *what*
|
||||
- Offer solutions, not just problems
|
||||
- Acknowledge good patterns you see
|
||||
224
skills/mlops/faiss/SKILL.md
Normal file
224
skills/mlops/faiss/SKILL.md
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
---
|
||||
name: faiss
|
||||
description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [faiss-cpu, faiss-gpu, numpy]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale]
|
||||
|
||||
---
|
||||
|
||||
# FAISS - Efficient Similarity Search
|
||||
|
||||
Facebook AI's library for billion-scale vector similarity search.
|
||||
|
||||
## When to use FAISS
|
||||
|
||||
**Use FAISS when:**
|
||||
- Need fast similarity search on large vector datasets (millions/billions)
|
||||
- GPU acceleration required
|
||||
- Pure vector similarity (no metadata filtering needed)
|
||||
- High throughput, low latency critical
|
||||
- Offline/batch processing of embeddings
|
||||
|
||||
**Metrics**:
|
||||
- **31,700+ GitHub stars**
|
||||
- Meta/Facebook AI Research
|
||||
- **Handles billions of vectors**
|
||||
- **C++** with Python bindings
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Chroma/Pinecone**: Need metadata filtering
|
||||
- **Weaviate**: Need full database features
|
||||
- **Annoy**: Simpler, fewer features
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# CPU only
|
||||
pip install faiss-cpu
|
||||
|
||||
# GPU support
|
||||
pip install faiss-gpu
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
# Create sample data (1000 vectors, 128 dimensions)
|
||||
d = 128
|
||||
nb = 1000
|
||||
vectors = np.random.random((nb, d)).astype('float32')
|
||||
|
||||
# Create index
|
||||
index = faiss.IndexFlatL2(d) # L2 distance
|
||||
index.add(vectors) # Add vectors
|
||||
|
||||
# Search
|
||||
k = 5 # Find 5 nearest neighbors
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
|
||||
print(f"Nearest neighbors: {indices}")
|
||||
print(f"Distances: {distances}")
|
||||
```
|
||||
|
||||
## Index types
|
||||
|
||||
### 1. Flat (exact search)
|
||||
|
||||
```python
|
||||
# L2 (Euclidean) distance
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Inner product (cosine similarity if normalized)
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Slowest, most accurate
|
||||
```
|
||||
|
||||
### 2. IVF (inverted file) - Fast approximate
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# IVF index with 100 clusters
|
||||
nlist = 100
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 3. HNSW (Hierarchical NSW) - Best quality/speed
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# No training needed
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 4. Product Quantization - Memory efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers
|
||||
nbits = 8
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train and add
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
## Save and load
|
||||
|
||||
```python
|
||||
# Save index
|
||||
faiss.write_index(index, "large.index")
|
||||
|
||||
# Load index
|
||||
index = faiss.read_index("large.index")
|
||||
|
||||
# Continue using
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
```python
|
||||
# Single GPU
|
||||
res = faiss.StandardGpuResources()
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Multi-GPU
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# 10-100× faster than CPU
|
||||
```
|
||||
|
||||
## LangChain integration
|
||||
|
||||
```python
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
# Create FAISS vector store
|
||||
vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
|
||||
|
||||
# Save
|
||||
vectorstore.save_local("faiss_index")
|
||||
|
||||
# Load
|
||||
vectorstore = FAISS.load_local(
|
||||
"faiss_index",
|
||||
OpenAIEmbeddings(),
|
||||
allow_dangerous_deserialization=True
|
||||
)
|
||||
|
||||
# Search
|
||||
results = vectorstore.similarity_search("query", k=5)
|
||||
```
|
||||
|
||||
## LlamaIndex integration
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
import faiss
|
||||
|
||||
# Create FAISS index
|
||||
d = 1536
|
||||
faiss_index = faiss.IndexFlatL2(d)
|
||||
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Choose right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality
|
||||
2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors
|
||||
3. **Use GPU for large datasets** - 10-100× faster
|
||||
4. **Save trained indices** - Training is expensive
|
||||
5. **Tune nprobe/ef_search** - Balance speed/accuracy
|
||||
6. **Monitor memory** - PQ for large datasets
|
||||
7. **Batch queries** - Better GPU utilization
|
||||
|
||||
## Performance
|
||||
|
||||
| Index Type | Build Time | Search Time | Memory | Accuracy |
|
||||
|------------|------------|-------------|--------|----------|
|
||||
| Flat | Fast | Slow | High | 100% |
|
||||
| IVF | Medium | Fast | Medium | 95-99% |
|
||||
| HNSW | Slow | Fastest | High | 99% |
|
||||
| PQ | Medium | Fast | Low | 90-95% |
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
280
skills/mlops/faiss/references/index_types.md
Normal file
280
skills/mlops/faiss/references/index_types.md
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
# FAISS Index Types Guide
|
||||
|
||||
Complete guide to choosing and using FAISS index types.
|
||||
|
||||
## Index selection guide
|
||||
|
||||
| Dataset Size | Index Type | Training | Accuracy | Speed |
|
||||
|--------------|------------|----------|----------|-------|
|
||||
| < 10K | Flat | No | 100% | Slow |
|
||||
| 10K-1M | IVF | Yes | 95-99% | Fast |
|
||||
| 1M-10M | HNSW | No | 99% | Fastest |
|
||||
| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory |
|
||||
|
||||
## Flat indices (exact search)
|
||||
|
||||
### IndexFlatL2 - L2 (Euclidean) distance
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
d = 128 # Dimension
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Add vectors
|
||||
vectors = np.random.random((1000, d)).astype('float32')
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
k = 5
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset < 10,000 vectors
|
||||
- Need 100% accuracy
|
||||
- Serving as baseline
|
||||
|
||||
### IndexFlatIP - Inner product (cosine similarity)
|
||||
|
||||
```python
|
||||
# For cosine similarity, normalize vectors first
|
||||
import faiss
|
||||
|
||||
d = 128
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Normalize vectors (required for cosine similarity)
|
||||
faiss.normalize_L2(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
faiss.normalize_L2(query)
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Need cosine similarity
|
||||
- Recommendation systems
|
||||
- Text embeddings
|
||||
|
||||
## IVF indices (inverted file)
|
||||
|
||||
### IndexIVFFlat - Cluster-based search
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# Create IVF index with 100 clusters
|
||||
nlist = 100 # Number of clusters
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10 # Search 10 closest clusters
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `nlist`: Number of clusters (√N to 4√N recommended)
|
||||
- `nprobe`: Clusters to search (1-nlist, higher = more accurate)
|
||||
|
||||
**Use when:**
|
||||
- Dataset 10K-1M vectors
|
||||
- Need fast approximate search
|
||||
- Can afford training time
|
||||
|
||||
### Tuning nprobe
|
||||
|
||||
```python
|
||||
# Test different nprobe values
|
||||
for nprobe in [1, 5, 10, 20, 50]:
|
||||
index.nprobe = nprobe
|
||||
distances, indices = index.search(query, k)
|
||||
# Measure recall/speed trade-off
|
||||
```
|
||||
|
||||
**Guidelines:**
|
||||
- `nprobe=1`: Fastest, ~50% recall
|
||||
- `nprobe=10`: Good balance, ~95% recall
|
||||
- `nprobe=nlist`: Exact search (same as Flat)
|
||||
|
||||
## HNSW indices (graph-based)
|
||||
|
||||
### IndexHNSWFlat - Hierarchical NSW
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer (16-64)
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# Optional: Set ef_construction (build time parameter)
|
||||
index.hnsw.efConstruction = 40 # Higher = better quality, slower build
|
||||
|
||||
# Add vectors (no training needed!)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.hnsw.efSearch = 16 # Search time parameter
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `M`: Connections per layer (16-64, default 32)
|
||||
- `efConstruction`: Build quality (40-200, higher = better)
|
||||
- `efSearch`: Search quality (16-512, higher = more accurate)
|
||||
|
||||
**Use when:**
|
||||
- Need best quality approximate search
|
||||
- Can afford higher memory (more connections)
|
||||
- Dataset 1M-10M vectors
|
||||
|
||||
## PQ indices (product quantization)
|
||||
|
||||
### IndexPQ - Memory-efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers (divides d)
|
||||
nbits = 8 # Bits per subquantizer
|
||||
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `m`: Subquantizers (d must be divisible by m)
|
||||
- `nbits`: Bits per code (8 or 16)
|
||||
|
||||
**Memory savings:**
|
||||
- Original: d × 4 bytes (float32)
|
||||
- PQ: m bytes
|
||||
- Compression ratio: 4d/m
|
||||
|
||||
**Use when:**
|
||||
- Limited memory
|
||||
- Large datasets (> 10M vectors)
|
||||
- Can accept ~90-95% accuracy
|
||||
|
||||
### IndexIVFPQ - IVF + PQ combined
|
||||
|
||||
```python
|
||||
# Best for very large datasets
|
||||
nlist = 4096
|
||||
m = 8
|
||||
nbits = 8
|
||||
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
|
||||
|
||||
# Train
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.nprobe = 32
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset > 10M vectors
|
||||
- Need fast search + low memory
|
||||
- Can accept 90-95% accuracy
|
||||
|
||||
## GPU indices
|
||||
|
||||
### Single GPU
|
||||
|
||||
```python
|
||||
import faiss
|
||||
|
||||
# Create CPU index
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
|
||||
# Move to GPU
|
||||
res = faiss.StandardGpuResources() # GPU resources
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Use normally
|
||||
index_gpu.add(vectors)
|
||||
distances, indices = index_gpu.search(query, k)
|
||||
```
|
||||
|
||||
### Multi-GPU
|
||||
|
||||
```python
|
||||
# Use all available GPUs
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# Or specific GPUs
|
||||
gpus = [0, 1, 2, 3] # Use GPUs 0-3
|
||||
index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus)
|
||||
```
|
||||
|
||||
**Speedup:**
|
||||
- Single GPU: 10-50× faster than CPU
|
||||
- Multi-GPU: Near-linear scaling
|
||||
|
||||
## Index factory
|
||||
|
||||
```python
|
||||
# Easy index creation with string descriptors
|
||||
index = faiss.index_factory(d, "IVF100,Flat")
|
||||
index = faiss.index_factory(d, "HNSW32")
|
||||
index = faiss.index_factory(d, "IVF4096,PQ8")
|
||||
|
||||
# Train and use
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
**Common descriptors:**
|
||||
- `"Flat"`: Exact search
|
||||
- `"IVF100,Flat"`: IVF with 100 clusters
|
||||
- `"HNSW32"`: HNSW with M=32
|
||||
- `"IVF4096,PQ8"`: IVF + PQ compression
|
||||
|
||||
## Performance comparison
|
||||
|
||||
### Search speed (1M vectors, k=10)
|
||||
|
||||
| Index | Build Time | Search Time | Memory | Recall |
|
||||
|-------|------------|-------------|--------|--------|
|
||||
| Flat | 0s | 50ms | 512 MB | 100% |
|
||||
| IVF100 | 5s | 2ms | 512 MB | 95% |
|
||||
| HNSW32 | 60s | 1ms | 1GB | 99% |
|
||||
| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% |
|
||||
|
||||
*CPU (16 cores), 128-dim vectors*
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with Flat** - Baseline for comparison
|
||||
2. **Use IVF for medium datasets** - Good balance
|
||||
3. **Use HNSW for best quality** - If memory allows
|
||||
4. **Add PQ for memory savings** - Large datasets
|
||||
5. **GPU for > 100K vectors** - 10-50× speedup
|
||||
6. **Tune nprobe/efSearch** - Trade-off speed/accuracy
|
||||
7. **Train on representative data** - Better clustering
|
||||
8. **Save trained indices** - Avoid retraining
|
||||
|
||||
## Resources
|
||||
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **Paper**: https://arxiv.org/abs/1702.08734
|
||||
370
skills/mlops/flash-attention/SKILL.md
Normal file
370
skills/mlops/flash-attention/SKILL.md
Normal file
|
|
@ -0,0 +1,370 @@
|
|||
---
|
||||
name: optimizing-attention-flash
|
||||
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [flash-attn, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
|
||||
|
||||
---
|
||||
|
||||
# Flash Attention - Fast Memory-Efficient Attention
|
||||
|
||||
## Quick start
|
||||
|
||||
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
|
||||
|
||||
**PyTorch native (easiest, PyTorch 2.2+)**:
|
||||
```python
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
|
||||
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Automatically uses Flash Attention if available
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**flash-attn library (more features)**:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q, k, v: [batch, seqlen, nheads, headdim]
|
||||
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Enable in existing PyTorch model
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Flash Attention Integration:
|
||||
- [ ] Step 1: Check PyTorch version (≥2.2)
|
||||
- [ ] Step 2: Enable Flash Attention backend
|
||||
- [ ] Step 3: Verify speedup with profiling
|
||||
- [ ] Step 4: Test accuracy matches baseline
|
||||
```
|
||||
|
||||
**Step 1: Check PyTorch version**
|
||||
|
||||
```bash
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
# Should be ≥2.2.0
|
||||
```
|
||||
|
||||
If <2.2, upgrade:
|
||||
```bash
|
||||
pip install --upgrade torch
|
||||
```
|
||||
|
||||
**Step 2: Enable Flash Attention backend**
|
||||
|
||||
Replace standard attention:
|
||||
```python
|
||||
# Before (standard attention)
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
|
||||
out = attn_weights @ v
|
||||
|
||||
# After (Flash Attention)
|
||||
import torch.nn.functional as F
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
```
|
||||
|
||||
Force Flash Attention backend:
|
||||
```python
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True,
|
||||
enable_math=False,
|
||||
enable_mem_efficient=False
|
||||
):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**Step 3: Verify speedup with profiling**
|
||||
|
||||
```python
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
def test_attention(use_flash):
|
||||
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
if use_flash:
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True):
|
||||
return F.scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
|
||||
return attn @ v
|
||||
|
||||
# Benchmark
|
||||
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
|
||||
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
|
||||
|
||||
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
|
||||
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
|
||||
```
|
||||
|
||||
Expected: 2-4x speedup for sequences >512 tokens.
|
||||
|
||||
**Step 4: Test accuracy matches baseline**
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Flash Attention
|
||||
out_flash = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
# Standard attention
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
|
||||
out_standard = attn_weights @ v
|
||||
|
||||
# Check difference
|
||||
diff = (out_flash - out_standard).abs().max()
|
||||
print(f"Max difference: {diff:.6f}")
|
||||
# Should be <1e-3 for float16
|
||||
```
|
||||
|
||||
### Workflow 2: Use flash-attn library for advanced features
|
||||
|
||||
For multi-query attention, sliding window, or H100 FP8.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
flash-attn Library Setup:
|
||||
- [ ] Step 1: Install flash-attn library
|
||||
- [ ] Step 2: Modify attention code
|
||||
- [ ] Step 3: Enable advanced features
|
||||
- [ ] Step 4: Benchmark performance
|
||||
```
|
||||
|
||||
**Step 1: Install flash-attn library**
|
||||
|
||||
```bash
|
||||
# NVIDIA GPUs (CUDA 12.0+)
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# Verify installation
|
||||
python -c "from flash_attn import flash_attn_func; print('Success')"
|
||||
```
|
||||
|
||||
**Step 2: Modify attention code**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# Input: [batch_size, seq_len, num_heads, head_dim]
|
||||
# Transpose from [batch, heads, seq, dim] if needed
|
||||
q = q.transpose(1, 2) # [batch, seq, heads, dim]
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
dropout_p=0.1,
|
||||
causal=True, # For autoregressive models
|
||||
window_size=(-1, -1), # No sliding window
|
||||
softmax_scale=None # Auto-scale
|
||||
)
|
||||
|
||||
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
|
||||
```
|
||||
|
||||
**Step 3: Enable advanced features**
|
||||
|
||||
Multi-query attention (shared K/V across heads):
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q: [batch, seq, num_q_heads, dim]
|
||||
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
|
||||
out = flash_attn_func(q, k, v) # Automatically handles MQA
|
||||
```
|
||||
|
||||
Sliding window attention (local attention):
|
||||
```python
|
||||
# Only attend to window of 256 tokens before/after
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
window_size=(256, 256), # (left, right) window
|
||||
causal=True
|
||||
)
|
||||
```
|
||||
|
||||
**Step 4: Benchmark performance**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from flash_attn import flash_attn_func
|
||||
import time
|
||||
|
||||
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = flash_attn_func(q, k, v)
|
||||
|
||||
# Benchmark
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for _ in range(100):
|
||||
out = flash_attn_func(q, k, v)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
|
||||
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
|
||||
```
|
||||
|
||||
### Workflow 3: H100 FP8 optimization (FlashAttention-3)
|
||||
|
||||
For maximum performance on H100 GPUs.
|
||||
|
||||
```
|
||||
FP8 Setup:
|
||||
- [ ] Step 1: Verify H100 GPU available
|
||||
- [ ] Step 2: Install flash-attn with FP8 support
|
||||
- [ ] Step 3: Convert inputs to FP8
|
||||
- [ ] Step 4: Run with FP8 attention
|
||||
```
|
||||
|
||||
**Step 1: Verify H100 GPU**
|
||||
|
||||
```bash
|
||||
nvidia-smi --query-gpu=name --format=csv
|
||||
# Should show "H100" or "H800"
|
||||
```
|
||||
|
||||
**Step 2: Install flash-attn with FP8 support**
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
# FP8 support included for H100
|
||||
```
|
||||
|
||||
**Step 3: Convert inputs to FP8**
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Convert to float8_e4m3 (FP8)
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
k_fp8 = k.to(torch.float8_e4m3fn)
|
||||
v_fp8 = v.to(torch.float8_e4m3fn)
|
||||
```
|
||||
|
||||
**Step 4: Run with FP8 attention**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# FlashAttention-3 automatically uses FP8 kernels on H100
|
||||
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
|
||||
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Flash Attention when:**
|
||||
- Training transformers with sequences >512 tokens
|
||||
- Running inference with long context (>2K tokens)
|
||||
- GPU memory constrained (OOM with standard attention)
|
||||
- Need 2-4x speedup without accuracy loss
|
||||
- Using PyTorch 2.2+ or can install flash-attn
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
|
||||
- **xFormers**: Need more attention variants (not just speed)
|
||||
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: ImportError: cannot import flash_attn**
|
||||
|
||||
Install with no-build-isolation flag:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or install CUDA toolkit first:
|
||||
```bash
|
||||
conda install cuda -c nvidia
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
**Issue: Slower than expected (no speedup)**
|
||||
|
||||
Flash Attention benefits increase with sequence length:
|
||||
- <512 tokens: Minimal speedup (10-20%)
|
||||
- 512-2K tokens: 2-3x speedup
|
||||
- >2K tokens: 3-4x speedup
|
||||
|
||||
Check sequence length is sufficient.
|
||||
|
||||
**Issue: RuntimeError: CUDA error**
|
||||
|
||||
Verify GPU supports Flash Attention:
|
||||
```python
|
||||
import torch
|
||||
print(torch.cuda.get_device_capability())
|
||||
# Should be ≥(7, 5) for Turing+
|
||||
```
|
||||
|
||||
Flash Attention requires:
|
||||
- Ampere (A100, A10): ✅ Full support
|
||||
- Turing (T4): ✅ Supported
|
||||
- Volta (V100): ❌ Not supported
|
||||
|
||||
**Issue: Accuracy degradation**
|
||||
|
||||
Check dtype is float16 or bfloat16 (not float32):
|
||||
```python
|
||||
q = q.to(torch.float16) # Or torch.bfloat16
|
||||
```
|
||||
|
||||
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
|
||||
|
||||
**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
|
||||
|
||||
**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
|
||||
|
||||
**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
|
||||
- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
|
||||
- **CUDA**: 12.0+ (11.8 minimum)
|
||||
- **PyTorch**: 2.2+ for native support
|
||||
|
||||
**Not supported**: V100 (Volta), CPU inference
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
|
||||
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
|
||||
- Blog: https://tridao.me/blog/2024/flash3/
|
||||
- GitHub: https://github.com/Dao-AILab/flash-attention
|
||||
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
|
||||
|
||||
|
||||
215
skills/mlops/flash-attention/references/benchmarks.md
Normal file
215
skills/mlops/flash-attention/references/benchmarks.md
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
# Performance Benchmarks
|
||||
|
||||
## Contents
|
||||
- Speed comparisons across GPUs
|
||||
- Memory usage analysis
|
||||
- Scaling with sequence length
|
||||
- Training vs inference performance
|
||||
- Flash Attention versions comparison
|
||||
|
||||
## Speed comparisons across GPUs
|
||||
|
||||
### A100 80GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|
||||
|------------|----------|--------------|--------------|---------------|
|
||||
| 512 | 1.2 | 0.9 | N/A | 1.3x |
|
||||
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
|
||||
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
|
||||
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
|
||||
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
|
||||
|
||||
### H100 80GB (Hopper)
|
||||
|
||||
**Forward pass time** (milliseconds, same config):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|
||||
|------------|----------|--------------|---------------------|--------------------|--------------|
|
||||
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
|
||||
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
|
||||
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
|
||||
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
|
||||
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
|
||||
|
||||
**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
|
||||
|
||||
### A10G 24GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=4):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Speedup |
|
||||
|------------|----------|--------------|---------|
|
||||
| 512 | 2.1 | 1.6 | 1.3x |
|
||||
| 1024 | 6.8 | 2.8 | 2.4x |
|
||||
| 2048 | 25.9 | 9.4 | 2.8x |
|
||||
| 4096 | 102.1 | 35.2 | 2.9x |
|
||||
|
||||
## Memory usage analysis
|
||||
|
||||
### GPU memory consumption (batch=8, heads=32, dim=64)
|
||||
|
||||
**Standard attention memory**:
|
||||
|
||||
| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|
||||
|------------|------------------|----------|-------|-------|
|
||||
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
|
||||
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
|
||||
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
|
||||
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
|
||||
|
||||
**Flash Attention 2 memory**:
|
||||
|
||||
| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|
||||
|------------|---------------------|----------|-------|-----------|
|
||||
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
|
||||
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
|
||||
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
|
||||
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
|
||||
|
||||
**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory.
|
||||
|
||||
### Memory scaling comparison
|
||||
|
||||
**Llama 2 7B model memory** (float16, batch=1):
|
||||
|
||||
| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|
||||
|----------------|-------------------|-------------------|-------------------|
|
||||
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
|
||||
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
|
||||
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
|
||||
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
|
||||
| 32K | OOM | 14.2 GB | Only Flash: Yes |
|
||||
|
||||
### Training memory (Llama 2 7B, batch=4)
|
||||
|
||||
| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|
||||
|---------|---------------|-----------------|-----------|
|
||||
| 2K | 18.2 | 12.4 | 32% |
|
||||
| 4K | 34.8 | 16.8 | 52% |
|
||||
| 8K | OOM (>40GB) | 26.2 | Fits! |
|
||||
|
||||
## Scaling with sequence length
|
||||
|
||||
### Computational complexity
|
||||
|
||||
**Standard attention**:
|
||||
- Time: O(N² × d)
|
||||
- Memory: O(N² + N × d)
|
||||
|
||||
**Flash Attention**:
|
||||
- Time: O(N² × d) (same, but with better constants)
|
||||
- Memory: O(N × d) (linear!)
|
||||
|
||||
### Empirical scaling (A100, batch=1, heads=32, dim=64)
|
||||
|
||||
**Time per token (milliseconds)**:
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
|
||||
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
|
||||
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
|
||||
|
||||
**Observation**: Speedup increases quadratically with sequence length!
|
||||
|
||||
### Memory per token (MB)
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
|
||||
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
|
||||
|
||||
**Observation**: Flash Attention memory per token is constant!
|
||||
|
||||
## Training vs inference performance
|
||||
|
||||
### Training (forward + backward, Llama 2 7B, A100)
|
||||
|
||||
| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------------|------------------------|--------------------------|---------|
|
||||
| 4 × 2K | 1.2 | 3.1 | 2.6x |
|
||||
| 8 × 2K | 2.1 | 5.8 | 2.8x |
|
||||
| 4 × 4K | 0.4 | 1.3 | 3.3x |
|
||||
| 8 × 4K | OOM | 2.4 | Enabled |
|
||||
| 2 × 8K | 0.1 | 0.4 | 4.0x |
|
||||
|
||||
### Inference (generation, Llama 2 7B, A100)
|
||||
|
||||
| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|----------------|----------------------|-------------------------|---------|
|
||||
| 512 | 48 | 52 | 1.1x |
|
||||
| 2K | 42 | 62 | 1.5x |
|
||||
| 4K | 31 | 58 | 1.9x |
|
||||
| 8K | 18 | 51 | 2.8x |
|
||||
| 16K | OOM | 42 | Enabled |
|
||||
|
||||
**Note**: Inference speedup less dramatic than training because generation is memory-bound (KV cache accesses).
|
||||
|
||||
## Flash Attention versions comparison
|
||||
|
||||
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
|
||||
|
||||
| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|
||||
|--------|-----|-----|------------|-----------|
|
||||
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
|
||||
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
|
||||
| TFLOPS | 180 | 420 | 740 | 1150 |
|
||||
| GPU util % | 35% | 55% | 75% | 82% |
|
||||
|
||||
**Key improvements**:
|
||||
- FA2: 2.3x faster than FA1 (better parallelism)
|
||||
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
|
||||
- FA3 (FP8): 2.6x faster than FA2 (low precision)
|
||||
|
||||
### Features by version
|
||||
|
||||
| Feature | FA1 | FA2 | FA3 |
|
||||
|---------|-----|-----|-----|
|
||||
| Basic attention | ✅ | ✅ | ✅ |
|
||||
| Causal masking | ✅ | ✅ | ✅ |
|
||||
| Multi-query attention | ❌ | ✅ | ✅ |
|
||||
| Sliding window | ❌ | ✅ | ✅ |
|
||||
| Paged KV cache | ❌ | ✅ | ✅ |
|
||||
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
|
||||
| Work partitioning | Basic | Advanced | Optimal |
|
||||
|
||||
## Real-world model benchmarks
|
||||
|
||||
### Llama 2 models (A100 80GB, batch=4, seq=2048)
|
||||
|
||||
| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------|--------|------------------------|--------------------------|---------|
|
||||
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
|
||||
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
|
||||
|
||||
### GPT-style models (seq=1024)
|
||||
|
||||
| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|-------|----------------------|-------------------------|---------|
|
||||
| GPT-2 (124M) | 520 | 680 | 1.3x |
|
||||
| GPT-J (6B) | 42 | 98 | 2.3x |
|
||||
| GPT-NeoX (20B) | 8 | 22 | 2.75x |
|
||||
|
||||
## Recommendations by use case
|
||||
|
||||
**Training large models (>7B parameters)**:
|
||||
- Use Flash Attention 2 on A100
|
||||
- Use Flash Attention 3 FP8 on H100 for maximum speed
|
||||
- Expected: 2.5-3x speedup
|
||||
|
||||
**Long context inference (>4K tokens)**:
|
||||
- Flash Attention essential (enables contexts standard attention can't handle)
|
||||
- Expected: 2-4x speedup, 5-10x memory reduction
|
||||
|
||||
**Short sequences (<512 tokens)**:
|
||||
- Flash Attention provides 1.2-1.5x speedup
|
||||
- Minimal memory benefit
|
||||
- Still worth enabling (no downside)
|
||||
|
||||
**Multi-user serving**:
|
||||
- Flash Attention reduces per-request memory
|
||||
- Allows higher concurrent batch sizes
|
||||
- Can serve 2-3x more users on same hardware
|
||||
|
|
@ -0,0 +1,293 @@
|
|||
# HuggingFace Transformers Integration
|
||||
|
||||
## Contents
|
||||
- Enabling Flash Attention in Transformers
|
||||
- Supported model architectures
|
||||
- Configuration examples
|
||||
- Performance comparisons
|
||||
- Troubleshooting model-specific issues
|
||||
|
||||
## Enabling Flash Attention in Transformers
|
||||
|
||||
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
|
||||
|
||||
**Simple enable for any supported model**:
|
||||
```python
|
||||
from transformers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
**Install requirements**:
|
||||
```bash
|
||||
pip install transformers>=4.36
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
## Supported model architectures
|
||||
|
||||
As of Transformers 4.40:
|
||||
|
||||
**Fully supported**:
|
||||
- Llama / Llama 2 / Llama 3
|
||||
- Mistral / Mixtral
|
||||
- Falcon
|
||||
- GPT-NeoX
|
||||
- Phi / Phi-2 / Phi-3
|
||||
- Qwen / Qwen2
|
||||
- Gemma
|
||||
- Starcoder2
|
||||
- GPT-J
|
||||
- OPT
|
||||
- BLOOM
|
||||
|
||||
**Partially supported** (encoder-decoder):
|
||||
- BART
|
||||
- T5 / Flan-T5
|
||||
- Whisper
|
||||
|
||||
**Check support**:
|
||||
```python
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.from_pretrained("model-name")
|
||||
print(config._attn_implementation_internal)
|
||||
# 'flash_attention_2' if supported
|
||||
```
|
||||
|
||||
## Configuration examples
|
||||
|
||||
### Llama 2 with Flash Attention
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-2-7b-hf"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Generate
|
||||
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_length=100)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
### Mistral with Flash Attention for long context
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for long context
|
||||
device_map="auto",
|
||||
max_position_embeddings=32768 # Extended context
|
||||
)
|
||||
|
||||
# Process long document (32K tokens)
|
||||
long_text = "..." * 10000
|
||||
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
```
|
||||
|
||||
### Fine-tuning with Flash Attention
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=3,
|
||||
fp16=True, # Must match model dtype
|
||||
optim="adamw_torch_fused" # Fast optimizer
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Multi-GPU training
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
# Model parallelism with Flash Attention
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto", # Automatic multi-GPU placement
|
||||
max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
|
||||
)
|
||||
```
|
||||
|
||||
## Performance comparisons
|
||||
|
||||
### Memory usage (Llama 2 7B, batch=1)
|
||||
|
||||
| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|
||||
|-----------------|-------------------|-------------------|-----------|
|
||||
| 512 | 1.2 GB | 0.9 GB | 25% |
|
||||
| 2048 | 3.8 GB | 1.4 GB | 63% |
|
||||
| 8192 | 14.2 GB | 3.2 GB | 77% |
|
||||
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |
|
||||
|
||||
### Speed (tokens/sec, A100 80GB)
|
||||
|
||||
| Model | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|----------|--------------|---------|
|
||||
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
|
||||
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
|
||||
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
|
||||
|
||||
### Training throughput (samples/sec)
|
||||
|
||||
| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|------------|----------|--------------|---------|
|
||||
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
|
||||
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
|
||||
|
||||
## Troubleshooting model-specific issues
|
||||
|
||||
### Issue: Model doesn't support Flash Attention
|
||||
|
||||
Check support list above. If not supported, use PyTorch SDPA as fallback:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="sdpa", # PyTorch native (still faster)
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: CUDA out of memory during loading
|
||||
|
||||
Reduce memory footprint:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
max_memory={0: "18GB"}, # Reserve memory for KV cache
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slower inference than expected
|
||||
|
||||
Ensure dtype matches:
|
||||
|
||||
```python
|
||||
# Model and inputs must both be float16/bfloat16
|
||||
model = model.to(torch.float16)
|
||||
inputs = tokenizer(..., return_tensors="pt").to("cuda")
|
||||
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
|
||||
for k, v in inputs.items()}
|
||||
```
|
||||
|
||||
### Issue: Different outputs vs standard attention
|
||||
|
||||
Flash Attention is numerically equivalent but uses different computation order. Small differences (<1e-3) are normal:
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
model_standard = AutoModelForCausalLM.from_pretrained("model-name", torch_dtype=torch.float16)
|
||||
model_flash = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
inputs = tokenizer("Test", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
out_standard = model_standard(**inputs).logits
|
||||
out_flash = model_flash(**inputs).logits
|
||||
|
||||
diff = (out_standard - out_flash).abs().max()
|
||||
print(f"Max diff: {diff:.6f}") # Should be ~1e-3 to 1e-4
|
||||
```
|
||||
|
||||
### Issue: ImportError during model loading
|
||||
|
||||
Install flash-attn:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or disable Flash Attention:
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="eager", # Standard PyTorch
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Always use float16/bfloat16** with Flash Attention (not float32)
|
||||
2. **Set device_map="auto"** for automatic memory management
|
||||
3. **Use bfloat16 for long context** (better numerical stability)
|
||||
4. **Enable gradient checkpointing** for training large models
|
||||
5. **Monitor memory** with `torch.cuda.max_memory_allocated()`
|
||||
|
||||
**Example with all best practices**:
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for training
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
# Enable gradient checkpointing for memory
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Training with optimizations
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=2,
|
||||
bf16=True, # Match model dtype
|
||||
optim="adamw_torch_fused",
|
||||
gradient_checkpointing=True
|
||||
)
|
||||
```
|
||||
430
skills/mlops/gguf/SKILL.md
Normal file
430
skills/mlops/gguf/SKILL.md
Normal file
|
|
@ -0,0 +1,430 @@
|
|||
---
|
||||
name: gguf-quantization
|
||||
description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [llama-cpp-python>=0.2.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization]
|
||||
|
||||
---
|
||||
|
||||
# GGUF - Quantization Format for llama.cpp
|
||||
|
||||
The GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options.
|
||||
|
||||
## When to use GGUF
|
||||
|
||||
**Use GGUF when:**
|
||||
- Deploying on consumer hardware (laptops, desktops)
|
||||
- Running on Apple Silicon (M1/M2/M3) with Metal acceleration
|
||||
- Need CPU inference without GPU requirements
|
||||
- Want flexible quantization (Q2_K to Q8_0)
|
||||
- Using local AI tools (LM Studio, Ollama, text-generation-webui)
|
||||
|
||||
**Key advantages:**
|
||||
- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support
|
||||
- **No Python runtime**: Pure C/C++ inference
|
||||
- **Flexible quantization**: 2-8 bit with various methods (K-quants)
|
||||
- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more
|
||||
- **imatrix**: Importance matrix for better low-bit quality
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs
|
||||
- **HQQ**: Fast calibration-free quantization for HuggingFace
|
||||
- **bitsandbytes**: Simple integration with transformers library
|
||||
- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
|
||||
# Build (CPU)
|
||||
make
|
||||
|
||||
# Build with CUDA (NVIDIA)
|
||||
make GGML_CUDA=1
|
||||
|
||||
# Build with Metal (Apple Silicon)
|
||||
make GGML_METAL=1
|
||||
|
||||
# Install Python bindings (optional)
|
||||
pip install llama-cpp-python
|
||||
```
|
||||
|
||||
### Convert model to GGUF
|
||||
|
||||
```bash
|
||||
# Install requirements
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Convert HuggingFace model to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf
|
||||
|
||||
# Or specify output type
|
||||
python convert_hf_to_gguf.py ./path/to/model \
|
||||
--outfile model-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
### Quantize model
|
||||
|
||||
```bash
|
||||
# Basic quantization to Q4_K_M
|
||||
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# Quantize with importance matrix (better quality)
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Run inference
|
||||
|
||||
```bash
|
||||
# CLI inference
|
||||
./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?"
|
||||
|
||||
# Interactive mode
|
||||
./llama-cli -m model-q4_k_m.gguf --interactive
|
||||
|
||||
# With GPU offload
|
||||
./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!"
|
||||
```
|
||||
|
||||
## Quantization types
|
||||
|
||||
### K-quant methods (recommended)
|
||||
|
||||
| Type | Bits | Size (7B) | Quality | Use Case |
|
||||
|------|------|-----------|---------|----------|
|
||||
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression |
|
||||
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
|
||||
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance |
|
||||
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance |
|
||||
| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** |
|
||||
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
|
||||
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
|
||||
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
|
||||
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality |
|
||||
|
||||
### Legacy methods
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| Q4_0 | 4-bit, basic |
|
||||
| Q4_1 | 4-bit with delta |
|
||||
| Q5_0 | 5-bit, basic |
|
||||
| Q5_1 | 5-bit with delta |
|
||||
|
||||
**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio.
|
||||
|
||||
## Conversion workflows
|
||||
|
||||
### Workflow 1: HuggingFace to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download model
|
||||
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
|
||||
|
||||
# 2. Convert to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./llama-3.1-8b \
|
||||
--outfile llama-3.1-8b-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# 3. Quantize
|
||||
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# 4. Test
|
||||
./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50
|
||||
```
|
||||
|
||||
### Workflow 2: With importance matrix (better quality)
|
||||
|
||||
```bash
|
||||
# 1. Convert to GGUF
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
|
||||
|
||||
# 2. Create calibration text (diverse samples)
|
||||
cat > calibration.txt << 'EOF'
|
||||
The quick brown fox jumps over the lazy dog.
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
Python is a popular programming language.
|
||||
# Add more diverse text samples...
|
||||
EOF
|
||||
|
||||
# 3. Generate importance matrix
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f calibration.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix \
|
||||
-ngl 35 # GPU layers if available
|
||||
|
||||
# 4. Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf \
|
||||
model-q4_k_m.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
### Workflow 3: Multiple quantizations
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
MODEL="llama-3.1-8b-f16.gguf"
|
||||
IMATRIX="llama-3.1-8b.imatrix"
|
||||
|
||||
# Generate imatrix once
|
||||
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
|
||||
|
||||
# Create multiple quantizations
|
||||
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
|
||||
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
|
||||
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
|
||||
done
|
||||
```
|
||||
|
||||
## Python usage
|
||||
|
||||
### llama-cpp-python
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Load model
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096, # Context window
|
||||
n_gpu_layers=35, # GPU offload (0 for CPU only)
|
||||
n_threads=8 # CPU threads
|
||||
)
|
||||
|
||||
# Generate
|
||||
output = llm(
|
||||
"What is machine learning?",
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
stop=["</s>", "\n\n"]
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Chat completion
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
chat_format="llama-3" # Or "chatml", "mistral", etc.
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"}
|
||||
]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=messages,
|
||||
max_tokens=256,
|
||||
temperature=0.7
|
||||
)
|
||||
print(response["choices"][0]["message"]["content"])
|
||||
```
|
||||
|
||||
### Streaming
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
# Stream tokens
|
||||
for chunk in llm(
|
||||
"Explain quantum computing:",
|
||||
max_tokens=256,
|
||||
stream=True
|
||||
):
|
||||
print(chunk["choices"][0]["text"], end="", flush=True)
|
||||
```
|
||||
|
||||
## Server mode
|
||||
|
||||
### Start OpenAI-compatible server
|
||||
|
||||
```bash
|
||||
# Start server
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096
|
||||
|
||||
# Or with Python bindings
|
||||
python -m llama_cpp.server \
|
||||
--model model-q4_k_m.gguf \
|
||||
--n_gpu_layers 35 \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080
|
||||
```
|
||||
|
||||
### Use with OpenAI client
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
max_tokens=256
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Hardware optimization
|
||||
|
||||
### Apple Silicon (Metal)
|
||||
|
||||
```bash
|
||||
# Build with Metal
|
||||
make clean && make GGML_METAL=1
|
||||
|
||||
# Run with Metal acceleration
|
||||
./llama-cli -m model.gguf -ngl 99 -p "Hello"
|
||||
|
||||
# Python with Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload all layers
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
### NVIDIA CUDA
|
||||
|
||||
```bash
|
||||
# Build with CUDA
|
||||
make clean && make GGML_CUDA=1
|
||||
|
||||
# Run with CUDA
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
|
||||
# Specify GPU
|
||||
CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35
|
||||
```
|
||||
|
||||
### CPU optimization
|
||||
|
||||
```bash
|
||||
# Build with AVX2/AVX512
|
||||
make clean && make
|
||||
|
||||
# Run with optimal threads
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
|
||||
# Python CPU config
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=0, # CPU only
|
||||
n_threads=8, # Match physical cores
|
||||
n_batch=512 # Batch size for prompt processing
|
||||
)
|
||||
```
|
||||
|
||||
## Integration with tools
|
||||
|
||||
### Ollama
|
||||
|
||||
```bash
|
||||
# Create Modelfile
|
||||
cat > Modelfile << 'EOF'
|
||||
FROM ./model-q4_k_m.gguf
|
||||
TEMPLATE """{{ .System }}
|
||||
{{ .Prompt }}"""
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER num_ctx 4096
|
||||
EOF
|
||||
|
||||
# Create Ollama model
|
||||
ollama create mymodel -f Modelfile
|
||||
|
||||
# Run
|
||||
ollama run mymodel "Hello!"
|
||||
```
|
||||
|
||||
### LM Studio
|
||||
|
||||
1. Place GGUF file in `~/.cache/lm-studio/models/`
|
||||
2. Open LM Studio and select the model
|
||||
3. Configure context length and GPU offload
|
||||
4. Start inference
|
||||
|
||||
### text-generation-webui
|
||||
|
||||
```bash
|
||||
# Place in models folder
|
||||
cp model-q4_k_m.gguf text-generation-webui/models/
|
||||
|
||||
# Start with llama.cpp loader
|
||||
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use K-quants**: Q4_K_M offers best quality/size balance
|
||||
2. **Use imatrix**: Always use importance matrix for Q4 and below
|
||||
3. **GPU offload**: Offload as many layers as VRAM allows
|
||||
4. **Context length**: Start with 4096, increase if needed
|
||||
5. **Thread count**: Match physical CPU cores, not logical
|
||||
6. **Batch size**: Increase n_batch for faster prompt processing
|
||||
|
||||
## Common issues
|
||||
|
||||
**Model loads slowly:**
|
||||
```bash
|
||||
# Use mmap for faster loading
|
||||
./llama-cli -m model.gguf --mmap
|
||||
```
|
||||
|
||||
**Out of memory:**
|
||||
```bash
|
||||
# Reduce GPU layers
|
||||
./llama-cli -m model.gguf -ngl 20 # Reduce from 35
|
||||
|
||||
# Or use smaller quantization
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
**Poor quality at low bits:**
|
||||
```bash
|
||||
# Always use imatrix for Q4 and below
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks
|
||||
|
||||
## Resources
|
||||
|
||||
- **Repository**: https://github.com/ggml-org/llama.cpp
|
||||
- **Python Bindings**: https://github.com/abetlen/llama-cpp-python
|
||||
- **Pre-quantized Models**: https://huggingface.co/TheBloke
|
||||
- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
|
||||
- **License**: MIT
|
||||
504
skills/mlops/gguf/references/advanced-usage.md
Normal file
504
skills/mlops/gguf/references/advanced-usage.md
Normal file
|
|
@ -0,0 +1,504 @@
|
|||
# GGUF Advanced Usage Guide
|
||||
|
||||
## Speculative Decoding
|
||||
|
||||
### Draft Model Approach
|
||||
|
||||
```bash
|
||||
# Use smaller model as draft for faster generation
|
||||
./llama-speculative \
|
||||
-m large-model-q4_k_m.gguf \
|
||||
-md draft-model-q4_k_m.gguf \
|
||||
-p "Write a story about AI" \
|
||||
-n 500 \
|
||||
--draft 8 # Draft tokens before verification
|
||||
```
|
||||
|
||||
### Self-Speculative Decoding
|
||||
|
||||
```bash
|
||||
# Use same model with different context for speculation
|
||||
./llama-cli -m model-q4_k_m.gguf \
|
||||
--lookup-cache-static lookup.bin \
|
||||
--lookup-cache-dynamic lookup-dynamic.bin \
|
||||
-p "Hello world"
|
||||
```
|
||||
|
||||
## Batched Inference
|
||||
|
||||
### Process Multiple Prompts
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
n_batch=512 # Larger batch for parallel processing
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"What is Python?",
|
||||
"Explain machine learning.",
|
||||
"Describe neural networks."
|
||||
]
|
||||
|
||||
# Process in batch (each prompt gets separate context)
|
||||
for prompt in prompts:
|
||||
output = llm(prompt, max_tokens=100)
|
||||
print(f"Q: {prompt}")
|
||||
print(f"A: {output['choices'][0]['text']}\n")
|
||||
```
|
||||
|
||||
### Server Batching
|
||||
|
||||
```bash
|
||||
# Start server with batching
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096 \
|
||||
--parallel 4 # Concurrent requests
|
||||
--cont-batching # Continuous batching
|
||||
```
|
||||
|
||||
## Custom Model Conversion
|
||||
|
||||
### Convert with Vocabulary Modifications
|
||||
|
||||
```python
|
||||
# custom_convert.py
|
||||
import sys
|
||||
sys.path.insert(0, './llama.cpp')
|
||||
|
||||
from convert_hf_to_gguf import main
|
||||
from gguf import GGUFWriter
|
||||
|
||||
# Custom conversion with modified vocab
|
||||
def convert_with_custom_vocab(model_path, output_path):
|
||||
# Load and modify tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
# Add special tokens if needed
|
||||
special_tokens = {"additional_special_tokens": ["<|custom|>"]}
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
tokenizer.save_pretrained(model_path)
|
||||
|
||||
# Then run standard conversion
|
||||
main([model_path, "--outfile", output_path])
|
||||
```
|
||||
|
||||
### Convert Specific Architecture
|
||||
|
||||
```bash
|
||||
# For Mistral-style models
|
||||
python convert_hf_to_gguf.py ./mistral-model \
|
||||
--outfile mistral-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Qwen models
|
||||
python convert_hf_to_gguf.py ./qwen-model \
|
||||
--outfile qwen-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Phi models
|
||||
python convert_hf_to_gguf.py ./phi-model \
|
||||
--outfile phi-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
## Advanced Quantization
|
||||
|
||||
### Mixed Quantization
|
||||
|
||||
```bash
|
||||
# Quantize different layer types differently
|
||||
./llama-quantize model-f16.gguf model-mixed.gguf Q4_K_M \
|
||||
--allow-requantize \
|
||||
--leave-output-tensor
|
||||
```
|
||||
|
||||
### Quantization with Token Embeddings
|
||||
|
||||
```bash
|
||||
# Keep embeddings at higher precision
|
||||
./llama-quantize model-f16.gguf model-q4.gguf Q4_K_M \
|
||||
--token-embedding-type f16
|
||||
```
|
||||
|
||||
### IQ Quantization (Importance-aware)
|
||||
|
||||
```bash
|
||||
# Ultra-low bit quantization with importance
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-iq2_xxs.gguf IQ2_XXS
|
||||
|
||||
# Available IQ types: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Memory Mapping
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Use memory mapping for large models
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
use_mmap=True, # Memory map the model
|
||||
use_mlock=False, # Don't lock in RAM
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Partial GPU Offload
|
||||
|
||||
```python
|
||||
# Calculate layers to offload based on VRAM
|
||||
import subprocess
|
||||
|
||||
def get_free_vram_gb():
|
||||
result = subprocess.run(
|
||||
['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return int(result.stdout.strip()) / 1024
|
||||
|
||||
# Estimate layers based on VRAM (rough: 0.5GB per layer for 7B Q4)
|
||||
free_vram = get_free_vram_gb()
|
||||
layers_to_offload = int(free_vram / 0.5)
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_gpu_layers=min(layers_to_offload, 35) # Cap at total layers
|
||||
)
|
||||
```
|
||||
|
||||
### KV Cache Optimization
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Optimize KV cache for long contexts
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=8192, # Large context
|
||||
n_gpu_layers=35,
|
||||
type_k=1, # Q8_0 for K cache (1)
|
||||
type_v=1, # Q8_0 for V cache (1)
|
||||
# Or use Q4_0 (2) for more compression
|
||||
)
|
||||
```
|
||||
|
||||
## Context Management
|
||||
|
||||
### Context Shifting
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Handle long conversations with context shifting
|
||||
conversation = []
|
||||
max_history = 10
|
||||
|
||||
def chat(user_message):
|
||||
conversation.append({"role": "user", "content": user_message})
|
||||
|
||||
# Keep only recent history
|
||||
if len(conversation) > max_history * 2:
|
||||
conversation = conversation[-max_history * 2:]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=conversation,
|
||||
max_tokens=256
|
||||
)
|
||||
|
||||
assistant_message = response["choices"][0]["message"]["content"]
|
||||
conversation.append({"role": "assistant", "content": assistant_message})
|
||||
return assistant_message
|
||||
```
|
||||
|
||||
### Save and Load State
|
||||
|
||||
```bash
|
||||
# Save state to file
|
||||
./llama-cli -m model.gguf \
|
||||
-p "Once upon a time" \
|
||||
--save-session session.bin \
|
||||
-n 100
|
||||
|
||||
# Load and continue
|
||||
./llama-cli -m model.gguf \
|
||||
--load-session session.bin \
|
||||
-p " and they lived" \
|
||||
-n 100
|
||||
```
|
||||
|
||||
## Grammar Constrained Generation
|
||||
|
||||
### JSON Output
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama, LlamaGrammar
|
||||
|
||||
# Define JSON grammar
|
||||
json_grammar = LlamaGrammar.from_string('''
|
||||
root ::= object
|
||||
object ::= "{" ws pair ("," ws pair)* "}" ws
|
||||
pair ::= string ":" ws value
|
||||
value ::= string | number | object | array | "true" | "false" | "null"
|
||||
array ::= "[" ws value ("," ws value)* "]" ws
|
||||
string ::= "\\"" [^"\\\\]* "\\""
|
||||
number ::= [0-9]+
|
||||
ws ::= [ \\t\\n]*
|
||||
''')
|
||||
|
||||
llm = Llama(model_path="model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
output = llm(
|
||||
"Output a JSON object with name and age:",
|
||||
grammar=json_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Custom Grammar
|
||||
|
||||
```python
|
||||
# Grammar for specific format
|
||||
answer_grammar = LlamaGrammar.from_string('''
|
||||
root ::= "Answer: " letter "\\n" "Explanation: " explanation
|
||||
letter ::= [A-D]
|
||||
explanation ::= [a-zA-Z0-9 .,!?]+
|
||||
''')
|
||||
|
||||
output = llm(
|
||||
"Q: What is 2+2? A) 3 B) 4 C) 5 D) 6",
|
||||
grammar=answer_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
```
|
||||
|
||||
## LoRA Integration
|
||||
|
||||
### Load LoRA Adapter
|
||||
|
||||
```bash
|
||||
# Apply LoRA at runtime
|
||||
./llama-cli -m base-model-q4_k_m.gguf \
|
||||
--lora lora-adapter.gguf \
|
||||
--lora-scale 1.0 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Multiple LoRA Adapters
|
||||
|
||||
```bash
|
||||
# Stack multiple adapters
|
||||
./llama-cli -m base-model.gguf \
|
||||
--lora adapter1.gguf --lora-scale 0.5 \
|
||||
--lora adapter2.gguf --lora-scale 0.5 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python LoRA Usage
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="base-model-q4_k_m.gguf",
|
||||
lora_path="lora-adapter.gguf",
|
||||
lora_scale=1.0,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
## Embedding Generation
|
||||
|
||||
### Extract Embeddings
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
embedding=True, # Enable embedding mode
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
embeddings = llm.embed("This is a test sentence.")
|
||||
print(f"Embedding dimension: {len(embeddings)}")
|
||||
```
|
||||
|
||||
### Batch Embeddings
|
||||
|
||||
```python
|
||||
texts = [
|
||||
"Machine learning is fascinating.",
|
||||
"Deep learning uses neural networks.",
|
||||
"Python is a programming language."
|
||||
]
|
||||
|
||||
embeddings = [llm.embed(text) for text in texts]
|
||||
|
||||
# Calculate similarity
|
||||
import numpy as np
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
|
||||
sim = cosine_similarity(embeddings[0], embeddings[1])
|
||||
print(f"Similarity: {sim:.4f}")
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Benchmark Script
|
||||
|
||||
```python
|
||||
import time
|
||||
from llama_cpp import Llama
|
||||
|
||||
def benchmark(model_path, prompt, n_tokens=100, n_runs=5):
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=35,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Warmup
|
||||
llm(prompt, max_tokens=10)
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
start = time.time()
|
||||
output = llm(prompt, max_tokens=n_tokens)
|
||||
elapsed = time.time() - start
|
||||
times.append(elapsed)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
tokens_per_sec = n_tokens / avg_time
|
||||
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Avg time: {avg_time:.2f}s")
|
||||
print(f"Tokens/sec: {tokens_per_sec:.1f}")
|
||||
|
||||
return tokens_per_sec
|
||||
|
||||
# Compare quantizations
|
||||
for quant in ["q4_k_m", "q5_k_m", "q8_0"]:
|
||||
benchmark(f"model-{quant}.gguf", "Explain quantum computing:", 100)
|
||||
```
|
||||
|
||||
### Optimal Configuration Finder
|
||||
|
||||
```python
|
||||
def find_optimal_config(model_path, target_vram_gb=8):
|
||||
"""Find optimal n_gpu_layers and n_batch for target VRAM."""
|
||||
from llama_cpp import Llama
|
||||
import gc
|
||||
|
||||
best_config = None
|
||||
best_speed = 0
|
||||
|
||||
for n_gpu_layers in range(0, 50, 5):
|
||||
for n_batch in [128, 256, 512, 1024]:
|
||||
try:
|
||||
gc.collect()
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=n_gpu_layers,
|
||||
n_batch=n_batch,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Quick benchmark
|
||||
start = time.time()
|
||||
llm("Hello", max_tokens=50)
|
||||
speed = 50 / (time.time() - start)
|
||||
|
||||
if speed > best_speed:
|
||||
best_speed = speed
|
||||
best_config = {
|
||||
"n_gpu_layers": n_gpu_layers,
|
||||
"n_batch": n_batch,
|
||||
"speed": speed
|
||||
}
|
||||
|
||||
del llm
|
||||
gc.collect()
|
||||
|
||||
except Exception as e:
|
||||
print(f"OOM at layers={n_gpu_layers}, batch={n_batch}")
|
||||
break
|
||||
|
||||
return best_config
|
||||
```
|
||||
|
||||
## Multi-GPU Setup
|
||||
|
||||
### Distribute Across GPUs
|
||||
|
||||
```bash
|
||||
# Split model across multiple GPUs
|
||||
./llama-cli -m large-model.gguf \
|
||||
--tensor-split 0.5,0.5 \
|
||||
-ngl 60 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python Multi-GPU
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="large-model-q4_k_m.gguf",
|
||||
n_gpu_layers=60,
|
||||
tensor_split=[0.5, 0.5] # Split evenly across 2 GPUs
|
||||
)
|
||||
```
|
||||
|
||||
## Custom Builds
|
||||
|
||||
### Build with All Optimizations
|
||||
|
||||
```bash
|
||||
# Clean build with all CPU optimizations
|
||||
make clean
|
||||
LLAMA_OPENBLAS=1 LLAMA_BLAS_VENDOR=OpenBLAS make -j
|
||||
|
||||
# With CUDA and cuBLAS
|
||||
make clean
|
||||
GGML_CUDA=1 LLAMA_CUBLAS=1 make -j
|
||||
|
||||
# With specific CUDA architecture
|
||||
GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_86 make -j
|
||||
```
|
||||
|
||||
### CMake Build
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build . --config Release -j
|
||||
```
|
||||
442
skills/mlops/gguf/references/troubleshooting.md
Normal file
442
skills/mlops/gguf/references/troubleshooting.md
Normal file
|
|
@ -0,0 +1,442 @@
|
|||
# GGUF Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Build Fails
|
||||
|
||||
**Error**: `make: *** No targets specified and no makefile found`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Ensure you're in llama.cpp directory
|
||||
cd llama.cpp
|
||||
make
|
||||
```
|
||||
|
||||
**Error**: `fatal error: cuda_runtime.h: No such file or directory`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install CUDA toolkit
|
||||
# Ubuntu
|
||||
sudo apt install nvidia-cuda-toolkit
|
||||
|
||||
# Or set CUDA path
|
||||
export CUDA_PATH=/usr/local/cuda
|
||||
export PATH=$CUDA_PATH/bin:$PATH
|
||||
make GGML_CUDA=1
|
||||
```
|
||||
|
||||
### Python Bindings Issues
|
||||
|
||||
**Error**: `ERROR: Failed building wheel for llama-cpp-python`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install build dependencies
|
||||
pip install cmake scikit-build-core
|
||||
|
||||
# For CUDA support
|
||||
CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
|
||||
# For Metal (macOS)
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
```
|
||||
|
||||
**Error**: `ImportError: libcudart.so.XX: cannot open shared object file`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Add CUDA libraries to path
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# Or reinstall with correct CUDA version
|
||||
pip uninstall llama-cpp-python
|
||||
CUDACXX=/usr/local/cuda/bin/nvcc CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python
|
||||
```
|
||||
|
||||
## Conversion Issues
|
||||
|
||||
### Model Not Supported
|
||||
|
||||
**Error**: `KeyError: 'model.embed_tokens.weight'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check model architecture
|
||||
python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('./model').architectures)"
|
||||
|
||||
# Use appropriate conversion script
|
||||
# For most models:
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# For older models, check if legacy script needed
|
||||
```
|
||||
|
||||
### Vocabulary Mismatch
|
||||
|
||||
**Error**: `RuntimeError: Vocabulary size mismatch`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure tokenizer matches model
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("./model")
|
||||
model = AutoModelForCausalLM.from_pretrained("./model")
|
||||
|
||||
print(f"Tokenizer vocab size: {len(tokenizer)}")
|
||||
print(f"Model vocab size: {model.config.vocab_size}")
|
||||
|
||||
# If mismatch, resize embeddings before conversion
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.save_pretrained("./model-fixed")
|
||||
```
|
||||
|
||||
### Out of Memory During Conversion
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during conversion
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Use CPU for conversion
|
||||
CUDA_VISIBLE_DEVICES="" python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# Or use low memory mode
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf --outtype f16
|
||||
```
|
||||
|
||||
## Quantization Issues
|
||||
|
||||
### Wrong Output File Size
|
||||
|
||||
**Problem**: Quantized file is larger than expected
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify quantization type
|
||||
./llama-cli -m model.gguf --verbose
|
||||
|
||||
# Expected sizes for 7B model:
|
||||
# Q4_K_M: ~4.1 GB
|
||||
# Q5_K_M: ~4.8 GB
|
||||
# Q8_0: ~7.2 GB
|
||||
# F16: ~13.5 GB
|
||||
```
|
||||
|
||||
### Quantization Crashes
|
||||
|
||||
**Error**: `Segmentation fault` during quantization
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Increase stack size
|
||||
ulimit -s unlimited
|
||||
|
||||
# Or use less threads
|
||||
./llama-quantize -t 4 model-f16.gguf model-q4.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Poor Quality After Quantization
|
||||
|
||||
**Problem**: Model outputs gibberish after quantization
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Use importance matrix**:
|
||||
```bash
|
||||
# Generate imatrix with good calibration data
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f wiki_sample.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix
|
||||
|
||||
# Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
2. **Try higher precision**:
|
||||
```bash
|
||||
# Use Q5_K_M or Q6_K instead of Q4
|
||||
./llama-quantize model-f16.gguf model-q5_k_m.gguf Q5_K_M
|
||||
```
|
||||
|
||||
3. **Check original model**:
|
||||
```bash
|
||||
# Test FP16 version first
|
||||
./llama-cli -m model-f16.gguf -p "Hello, how are you?" -n 50
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Slow Generation
|
||||
|
||||
**Problem**: Generation is slower than expected
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable GPU offload**:
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
```
|
||||
|
||||
2. **Optimize batch size**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_batch=512, # Increase for faster prompt processing
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
3. **Use appropriate threads**:
|
||||
```bash
|
||||
# Match physical cores, not logical
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
```
|
||||
|
||||
4. **Enable Flash Attention** (if supported):
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 --flash-attn -p "Hello"
|
||||
```
|
||||
|
||||
### Out of Memory
|
||||
|
||||
**Error**: `CUDA out of memory` or system freeze
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce GPU layers**:
|
||||
```python
|
||||
# Start low and increase
|
||||
llm = Llama(model_path="model.gguf", n_gpu_layers=10)
|
||||
```
|
||||
|
||||
2. **Use smaller quantization**:
|
||||
```bash
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
3. **Reduce context length**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_ctx=2048, # Reduce from 4096
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
4. **Quantize KV cache**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
type_k=2, # Q4_0 for K cache
|
||||
type_v=2, # Q4_0 for V cache
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Garbage Output
|
||||
|
||||
**Problem**: Model outputs random characters or nonsense
|
||||
|
||||
**Diagnose**:
|
||||
```python
|
||||
# Check model loading
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
|
||||
# Test with simple prompt
|
||||
output = llm("1+1=", max_tokens=5, temperature=0)
|
||||
print(output)
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check model integrity**:
|
||||
```bash
|
||||
# Verify GGUF file
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -50
|
||||
```
|
||||
|
||||
2. **Use correct chat format**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
chat_format="llama-3" # Match your model: chatml, mistral, etc.
|
||||
)
|
||||
```
|
||||
|
||||
3. **Check temperature**:
|
||||
```python
|
||||
# Use lower temperature for deterministic output
|
||||
output = llm("Hello", max_tokens=50, temperature=0.1)
|
||||
```
|
||||
|
||||
### Token Issues
|
||||
|
||||
**Error**: `RuntimeError: unknown token` or encoding errors
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure UTF-8 encoding
|
||||
prompt = "Hello, world!".encode('utf-8').decode('utf-8')
|
||||
output = llm(prompt, max_tokens=50)
|
||||
```
|
||||
|
||||
## Server Issues
|
||||
|
||||
### Connection Refused
|
||||
|
||||
**Error**: `Connection refused` when accessing server
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Bind to all interfaces
|
||||
./llama-server -m model.gguf --host 0.0.0.0 --port 8080
|
||||
|
||||
# Check if port is in use
|
||||
lsof -i :8080
|
||||
```
|
||||
|
||||
### Server Crashes Under Load
|
||||
|
||||
**Problem**: Server crashes with multiple concurrent requests
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Limit parallelism**:
|
||||
```bash
|
||||
./llama-server -m model.gguf \
|
||||
--parallel 2 \
|
||||
-c 4096 \
|
||||
--cont-batching
|
||||
```
|
||||
|
||||
2. **Add request timeout**:
|
||||
```bash
|
||||
./llama-server -m model.gguf --timeout 300
|
||||
```
|
||||
|
||||
3. **Monitor memory**:
|
||||
```bash
|
||||
watch -n 1 nvidia-smi # For GPU
|
||||
watch -n 1 free -h # For RAM
|
||||
```
|
||||
|
||||
### API Compatibility Issues
|
||||
|
||||
**Problem**: OpenAI client not working with server
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Use correct base URL format
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1", # Include /v1
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
# Use correct model name
|
||||
response = client.chat.completions.create(
|
||||
model="local", # Or the actual model name
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
```
|
||||
|
||||
## Apple Silicon Issues
|
||||
|
||||
### Metal Not Working
|
||||
|
||||
**Problem**: Metal acceleration not enabled
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify Metal support
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | grep -i metal
|
||||
```
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Rebuild with Metal
|
||||
make clean
|
||||
make GGML_METAL=1
|
||||
|
||||
# Python bindings
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall
|
||||
```
|
||||
|
||||
### Incorrect Memory Usage on M1/M2
|
||||
|
||||
**Problem**: Model uses too much unified memory
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Offload all layers for Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload everything
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
### Enable Verbose Output
|
||||
|
||||
```bash
|
||||
# CLI verbose mode
|
||||
./llama-cli -m model.gguf --verbose -p "Hello" -n 50
|
||||
|
||||
# Python verbose
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
```
|
||||
|
||||
### Check Model Metadata
|
||||
|
||||
```bash
|
||||
# View GGUF metadata
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -100
|
||||
```
|
||||
|
||||
### Validate GGUF File
|
||||
|
||||
```python
|
||||
import struct
|
||||
|
||||
def validate_gguf(filepath):
|
||||
with open(filepath, 'rb') as f:
|
||||
magic = f.read(4)
|
||||
if magic != b'GGUF':
|
||||
print(f"Invalid magic: {magic}")
|
||||
return False
|
||||
|
||||
version = struct.unpack('<I', f.read(4))[0]
|
||||
print(f"GGUF version: {version}")
|
||||
|
||||
tensor_count = struct.unpack('<Q', f.read(8))[0]
|
||||
metadata_count = struct.unpack('<Q', f.read(8))[0]
|
||||
print(f"Tensors: {tensor_count}, Metadata: {metadata_count}")
|
||||
|
||||
return True
|
||||
|
||||
validate_gguf("model.gguf")
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/ggml-org/llama.cpp/issues
|
||||
2. **Discussions**: https://github.com/ggml-org/llama.cpp/discussions
|
||||
3. **Reddit**: r/LocalLLaMA
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- llama.cpp version/commit hash
|
||||
- Build command used
|
||||
- Model name and quantization
|
||||
- Full error message/stack trace
|
||||
- Hardware: CPU/GPU model, RAM, VRAM
|
||||
- OS version
|
||||
- Minimal reproduction steps
|
||||
97
skills/mlops/grpo-rl-training/README.md
Normal file
97
skills/mlops/grpo-rl-training/README.md
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
# GRPO/RL Training Skill
|
||||
|
||||
**Expert-level guidance for Group Relative Policy Optimization with TRL**
|
||||
|
||||
## 📁 Skill Structure
|
||||
|
||||
```
|
||||
grpo-rl-training/
|
||||
├── SKILL.md # Main skill documentation (READ THIS FIRST)
|
||||
├── README.md # This file
|
||||
├── templates/
|
||||
│ └── basic_grpo_training.py # Production-ready training template
|
||||
└── examples/
|
||||
└── reward_functions_library.py # 20+ reward function examples
|
||||
```
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
|
||||
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
|
||||
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
|
||||
4. **Modify for your use case** - Adapt dataset, rewards, and config
|
||||
|
||||
## 💡 What's Inside
|
||||
|
||||
### SKILL.md (Main Documentation)
|
||||
- Core GRPO concepts and algorithm fundamentals
|
||||
- Complete implementation workflow (dataset → rewards → training → deployment)
|
||||
- 10+ reward function examples with code
|
||||
- Hyperparameter tuning guide
|
||||
- Training insights (loss behavior, metrics, debugging)
|
||||
- Troubleshooting guide
|
||||
- Production best practices
|
||||
|
||||
### Templates
|
||||
- **basic_grpo_training.py**: Minimal, production-ready training script
|
||||
- Uses Qwen 2.5 1.5B Instruct
|
||||
- 3 reward functions (format + correctness)
|
||||
- LoRA for efficient training
|
||||
- Fully documented and ready to run
|
||||
|
||||
### Examples
|
||||
- **reward_functions_library.py**: 20+ battle-tested reward functions
|
||||
- Correctness rewards (exact match, fuzzy match, numeric, code execution)
|
||||
- Format rewards (XML, JSON, strict/soft)
|
||||
- Length rewards (ideal length, min/max)
|
||||
- Style rewards (reasoning quality, citations, repetition penalty)
|
||||
- Combined rewards (multi-objective optimization)
|
||||
- Preset collections for common tasks
|
||||
|
||||
## 📖 Usage for Agents
|
||||
|
||||
When this skill is loaded in your agent's context:
|
||||
|
||||
1. **Always read SKILL.md first** before implementing
|
||||
2. **Start simple** - Use length-based reward to validate setup
|
||||
3. **Build incrementally** - Add one reward function at a time
|
||||
4. **Reference examples** - Copy patterns from reward_functions_library.py
|
||||
5. **Monitor training** - Watch reward metrics (not loss!)
|
||||
|
||||
## 🎯 Common Use Cases
|
||||
|
||||
| Task Type | Recommended Rewards | Template |
|
||||
|-----------|---------------------|----------|
|
||||
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
|
||||
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
|
||||
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
|
||||
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |
|
||||
|
||||
## ⚠️ Critical Reminders
|
||||
|
||||
- **Loss goes UP during training** - This is normal (it's KL divergence)
|
||||
- **Use 3-5 reward functions** - Single rewards often fail
|
||||
- **Test rewards before training** - Debug each function independently
|
||||
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with num_generations=4-8** - Scale up if GPU allows
|
||||
|
||||
## 🔗 External Resources
|
||||
|
||||
- [TRL Documentation](https://huggingface.co/docs/trl)
|
||||
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
|
||||
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
|
||||
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)
|
||||
|
||||
## 📝 Version
|
||||
|
||||
**v1.0.0** - Initial release (January 2025)
|
||||
|
||||
## 👨💻 Maintained By
|
||||
|
||||
Orchestra Research
|
||||
For questions or improvements, see https://orchestra.com
|
||||
|
||||
---
|
||||
|
||||
**License:** MIT
|
||||
**Last Updated:** January 2025
|
||||
575
skills/mlops/grpo-rl-training/SKILL.md
Normal file
575
skills/mlops/grpo-rl-training/SKILL.md
Normal file
|
|
@ -0,0 +1,575 @@
|
|||
---
|
||||
name: grpo-rl-training
|
||||
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]
|
||||
|
||||
---
|
||||
|
||||
# GRPO/RL Training with TRL
|
||||
|
||||
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use GRPO training when you need to:
|
||||
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
|
||||
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
|
||||
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
|
||||
- **Align models to domain-specific behaviors** without labeled preference data
|
||||
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
|
||||
|
||||
**Do NOT use GRPO for:**
|
||||
- Simple supervised fine-tuning tasks (use SFT instead)
|
||||
- Tasks without clear reward signals
|
||||
- When you already have high-quality preference pairs (use DPO/PPO instead)
|
||||
|
||||
---
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. GRPO Algorithm Fundamentals
|
||||
|
||||
**Key Mechanism:**
|
||||
- Generates **multiple completions** for each prompt (group size: 4-16)
|
||||
- Compares completions within each group using reward functions
|
||||
- Updates policy to favor higher-rewarded responses relative to the group
|
||||
|
||||
**Critical Difference from PPO:**
|
||||
- No separate reward model needed
|
||||
- More sample-efficient (learns from within-group comparisons)
|
||||
- Simpler to implement and debug
|
||||
|
||||
**Mathematical Intuition:**
|
||||
```
|
||||
For each prompt p:
|
||||
1. Generate N completions: {c₁, c₂, ..., cₙ}
|
||||
2. Compute rewards: {r₁, r₂, ..., rₙ}
|
||||
3. Learn to increase probability of high-reward completions
|
||||
relative to low-reward ones in the same group
|
||||
```
|
||||
|
||||
### 2. Reward Function Design Philosophy
|
||||
|
||||
**Golden Rules:**
|
||||
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** - Higher weight = stronger signal
|
||||
3. **Use incremental rewards** - Partial credit for partial compliance
|
||||
4. **Test rewards independently** - Debug each reward function in isolation
|
||||
|
||||
**Reward Function Types:**
|
||||
|
||||
| Type | Use Case | Example Weight |
|
||||
|------|----------|----------------|
|
||||
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
|
||||
| **Format** | Strict structure enforcement | 0.5-1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
|
||||
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Workflow
|
||||
|
||||
### Step 1: Dataset Preparation
|
||||
|
||||
**Critical Requirements:**
|
||||
- Prompts in chat format (list of dicts with 'role' and 'content')
|
||||
- Include system prompts to set expectations
|
||||
- For verifiable tasks, include ground truth answers as additional columns
|
||||
|
||||
**Example Structure:**
|
||||
```python
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
def prepare_dataset(raw_data):
|
||||
"""
|
||||
Transform raw data into GRPO-compatible format.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content (system + user messages)
|
||||
- 'answer': str (ground truth, optional but recommended)
|
||||
"""
|
||||
return raw_data.map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': extract_answer(x['raw_answer'])
|
||||
})
|
||||
```
|
||||
|
||||
**Pro Tips:**
|
||||
- Use one-shot or few-shot examples in system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256-512 tokens)
|
||||
- Validate data quality before training (garbage in = garbage out)
|
||||
|
||||
### Step 2: Reward Function Implementation
|
||||
|
||||
**Template Structure:**
|
||||
```python
|
||||
def reward_function_name(
|
||||
prompts, # List[List[Dict]]: Original prompts
|
||||
completions, # List[List[Dict]]: Model generations
|
||||
answer=None, # Optional: Ground truth from dataset
|
||||
**kwargs # Additional dataset columns
|
||||
) -> list[float]:
|
||||
"""
|
||||
Evaluate completions and return rewards.
|
||||
|
||||
Returns: List of floats (one per completion)
|
||||
"""
|
||||
# Extract completion text
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
|
||||
# Compute rewards
|
||||
rewards = []
|
||||
for response in responses:
|
||||
score = compute_score(response)
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Example 1: Correctness Reward (Math/Coding)**
|
||||
```python
|
||||
def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
"""Reward correct answers with high score."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_final_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0
|
||||
for ans, gt in zip(extracted, answer)]
|
||||
```
|
||||
|
||||
**Example 2: Format Reward (Structured Output)**
|
||||
```python
|
||||
import re
|
||||
|
||||
def format_reward(completions, **kwargs):
|
||||
"""Reward XML-like structured format."""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
|
||||
for r in responses]
|
||||
```
|
||||
|
||||
**Example 3: Incremental Format Reward (Partial Credit)**
|
||||
```python
|
||||
def incremental_format_reward(completions, **kwargs):
|
||||
"""Award partial credit for format compliance."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.25
|
||||
if '</reasoning>' in r:
|
||||
score += 0.25
|
||||
if '<answer>' in r:
|
||||
score += 0.25
|
||||
if '</answer>' in r:
|
||||
score += 0.25
|
||||
# Penalize extra text after closing tag
|
||||
if r.count('</answer>') == 1:
|
||||
extra_text = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra_text) * 0.001
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Critical Insight:**
|
||||
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
|
||||
### Step 3: Training Configuration
|
||||
|
||||
**Memory-Optimized Config (Small GPU)**
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6, # Lower = more stable
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4, # Effective batch = 4
|
||||
|
||||
# GRPO-specific
|
||||
num_generations=8, # Group size: 8-16 recommended
|
||||
max_prompt_length=256,
|
||||
max_completion_length=512,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
max_steps=None, # Or set fixed steps (e.g., 500)
|
||||
|
||||
# Optimization
|
||||
bf16=True, # Faster on A100/H100
|
||||
optim="adamw_8bit", # Memory-efficient optimizer
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Or "none" for no logging
|
||||
)
|
||||
```
|
||||
|
||||
**High-Performance Config (Large GPU)**
|
||||
```python
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
learning_rate=1e-5,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=2,
|
||||
num_generations=16, # Larger groups = better signal
|
||||
max_prompt_length=512,
|
||||
max_completion_length=1024,
|
||||
num_train_epochs=1,
|
||||
bf16=True,
|
||||
use_vllm=True, # Fast generation with vLLM
|
||||
logging_steps=10,
|
||||
)
|
||||
```
|
||||
|
||||
**Critical Hyperparameters:**
|
||||
|
||||
| Parameter | Impact | Tuning Advice |
|
||||
|-----------|--------|---------------|
|
||||
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
|
||||
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
|
||||
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
|
||||
|
||||
### Step 4: Model Setup and Training
|
||||
|
||||
**Standard Setup (Transformers)**
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load model
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2", # 2-3x faster
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Optional: LoRA for parameter-efficient training
|
||||
peft_config = LoraConfig(
|
||||
r=16, # Rank (higher = more capacity)
|
||||
lora_alpha=32, # Scaling factor (typically 2*r)
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward,
|
||||
format_reward,
|
||||
correctness_reward,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.train()
|
||||
|
||||
# Save
|
||||
trainer.save_model("final_model")
|
||||
```
|
||||
|
||||
**Unsloth Setup (2-3x Faster)**
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="google/gemma-3-1b-it",
|
||||
max_seq_length=1024,
|
||||
load_in_4bit=True,
|
||||
fast_inference=True,
|
||||
max_lora_rank=32,
|
||||
)
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=32,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"],
|
||||
lora_alpha=32,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
# Rest is identical to standard setup
|
||||
trainer = GRPOTrainer(model=model, ...)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Critical Training Insights
|
||||
|
||||
### 1. Loss Behavior (EXPECTED PATTERN)
|
||||
- **Loss starts near 0 and INCREASES during training**
|
||||
- This is CORRECT - loss measures KL divergence from initial policy
|
||||
- Model is learning (diverging from original behavior to optimize rewards)
|
||||
- Monitor reward metrics instead of loss for progress
|
||||
|
||||
### 2. Reward Tracking
|
||||
Key metrics to watch:
|
||||
- `reward`: Average across all completions
|
||||
- `reward_std`: Diversity within groups (should remain > 0)
|
||||
- `kl`: KL divergence from reference (should grow moderately)
|
||||
|
||||
**Healthy Training Pattern:**
|
||||
```
|
||||
Step Reward Reward_Std KL
|
||||
100 0.5 0.3 0.02
|
||||
200 0.8 0.25 0.05
|
||||
300 1.2 0.2 0.08 ← Good progression
|
||||
400 1.5 0.15 0.12
|
||||
```
|
||||
|
||||
**Warning Signs:**
|
||||
- Reward std → 0 (model collapsing to single response)
|
||||
- KL exploding (> 0.5) (diverging too much, reduce LR)
|
||||
- Reward stuck (reward functions too harsh or model capacity issue)
|
||||
|
||||
### 3. Common Pitfalls and Solutions
|
||||
|
||||
| Problem | Symptom | Solution |
|
||||
|---------|---------|----------|
|
||||
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
|
||||
| **No learning** | Flat rewards | Check reward function logic, increase LR |
|
||||
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
|
||||
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
|
||||
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
|
||||
|
||||
---
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### 1. Multi-Stage Training
|
||||
For complex tasks, train in stages:
|
||||
|
||||
```python
|
||||
# Stage 1: Format compliance (epochs=1)
|
||||
trainer_stage1 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[incremental_format_reward, format_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage1.train()
|
||||
|
||||
# Stage 2: Correctness (epochs=1)
|
||||
trainer_stage2 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[format_reward, correctness_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage2.train()
|
||||
```
|
||||
|
||||
### 2. Adaptive Reward Scaling
|
||||
```python
|
||||
class AdaptiveReward:
|
||||
def __init__(self, base_reward_func, initial_weight=1.0):
|
||||
self.func = base_reward_func
|
||||
self.weight = initial_weight
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
rewards = self.func(*args, **kwargs)
|
||||
return [r * self.weight for r in rewards]
|
||||
|
||||
def adjust_weight(self, success_rate):
|
||||
"""Increase weight if model struggling, decrease if succeeding."""
|
||||
if success_rate < 0.3:
|
||||
self.weight *= 1.2
|
||||
elif success_rate > 0.8:
|
||||
self.weight *= 0.9
|
||||
```
|
||||
|
||||
### 3. Custom Dataset Integration
|
||||
```python
|
||||
def load_custom_knowledge_base(csv_path):
|
||||
"""Example: School communication platform docs."""
|
||||
import pandas as pd
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
dataset = Dataset.from_pandas(df).map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': x['expert_answer']
|
||||
})
|
||||
return dataset
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Deployment and Inference
|
||||
|
||||
### Save and Merge LoRA
|
||||
```python
|
||||
# Merge LoRA adapters into base model
|
||||
if hasattr(trainer.model, 'merge_and_unload'):
|
||||
merged_model = trainer.model.merge_and_unload()
|
||||
merged_model.save_pretrained("production_model")
|
||||
tokenizer.save_pretrained("production_model")
|
||||
```
|
||||
|
||||
### Inference Example
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline(
|
||||
"text-generation",
|
||||
model="production_model",
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
result = generator(
|
||||
[
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': "What is 15 + 27?"}
|
||||
],
|
||||
max_new_tokens=256,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9
|
||||
)
|
||||
print(result[0]['generated_text'])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices Checklist
|
||||
|
||||
**Before Training:**
|
||||
- [ ] Validate dataset format (prompts as List[Dict])
|
||||
- [ ] Test reward functions on sample data
|
||||
- [ ] Calculate expected max_prompt_length from data
|
||||
- [ ] Choose appropriate num_generations based on GPU memory
|
||||
- [ ] Set up logging (wandb recommended)
|
||||
|
||||
**During Training:**
|
||||
- [ ] Monitor reward progression (should increase)
|
||||
- [ ] Check reward_std (should stay > 0.1)
|
||||
- [ ] Watch for OOM errors (reduce batch size if needed)
|
||||
- [ ] Sample generations every 50-100 steps
|
||||
- [ ] Validate format compliance on holdout set
|
||||
|
||||
**After Training:**
|
||||
- [ ] Merge LoRA weights if using PEFT
|
||||
- [ ] Test on diverse prompts
|
||||
- [ ] Compare to baseline model
|
||||
- [ ] Document reward weights and hyperparameters
|
||||
- [ ] Save reproducibility config
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting Guide
|
||||
|
||||
### Debugging Workflow
|
||||
1. **Isolate reward functions** - Test each independently
|
||||
2. **Check data distribution** - Ensure diversity in prompts
|
||||
3. **Reduce complexity** - Start with single reward, add gradually
|
||||
4. **Monitor generations** - Print samples every N steps
|
||||
5. **Validate extraction logic** - Ensure answer parsing works
|
||||
|
||||
### Quick Fixes
|
||||
```python
|
||||
# Debug reward function
|
||||
def debug_reward(completions, **kwargs):
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
for i, r in enumerate(responses[:2]): # Print first 2
|
||||
print(f"Response {i}: {r[:200]}...")
|
||||
return [1.0] * len(responses) # Dummy rewards
|
||||
|
||||
# Test without training
|
||||
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
|
||||
trainer.generate_completions(dataset[:1]) # Generate without updating
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## References and Resources
|
||||
|
||||
**Official Documentation:**
|
||||
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
|
||||
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
|
||||
- Unsloth Docs: https://docs.unsloth.ai/
|
||||
|
||||
**Example Repositories:**
|
||||
- Open R1 Implementation: https://github.com/huggingface/open-r1
|
||||
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
|
||||
**Recommended Reading:**
|
||||
- Progressive Disclosure Pattern for agent instructions
|
||||
- Reward shaping in RL (Ng et al.)
|
||||
- LoRA paper (Hu et al., 2021)
|
||||
|
||||
---
|
||||
|
||||
## Usage Instructions for Agents
|
||||
|
||||
When this skill is loaded:
|
||||
|
||||
1. **Read this entire file** before implementing GRPO training
|
||||
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
|
||||
3. **Use the templates** in `templates/` directory as starting points
|
||||
4. **Reference examples** in `examples/` for task-specific implementations
|
||||
5. **Follow the workflow** sequentially (don't skip steps)
|
||||
6. **Debug incrementally** - add one reward function at a time
|
||||
|
||||
**Critical Reminders:**
|
||||
- Always use multiple reward functions (3-5 is optimal)
|
||||
- Monitor reward metrics, not loss
|
||||
- Test reward functions before training
|
||||
- Start small (num_generations=4), scale up gradually
|
||||
- Save checkpoints frequently (every 100 steps)
|
||||
|
||||
This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.
|
||||
|
||||
|
||||
|
||||
228
skills/mlops/grpo-rl-training/templates/basic_grpo_training.py
Normal file
228
skills/mlops/grpo-rl-training/templates/basic_grpo_training.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""
|
||||
Basic GRPO Training Template
|
||||
=============================
|
||||
|
||||
A minimal, production-ready template for GRPO training with TRL.
|
||||
Adapt this for your specific task by modifying:
|
||||
1. Dataset loading (get_dataset function)
|
||||
2. Reward functions (reward_*_func)
|
||||
3. System prompt (SYSTEM_PROMPT)
|
||||
4. Hyperparameters (GRPOConfig)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import re
|
||||
from datasets import load_dataset, Dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
# ==================== CONFIGURATION ====================
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
OUTPUT_DIR = "outputs/grpo-model"
|
||||
MAX_PROMPT_LENGTH = 256
|
||||
MAX_COMPLETION_LENGTH = 512
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
# ==================== DATASET ====================
|
||||
|
||||
def get_dataset(split="train"):
|
||||
"""
|
||||
Load and prepare your dataset.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content
|
||||
- 'answer': str (ground truth, optional)
|
||||
"""
|
||||
# Example: GSM8K math dataset
|
||||
data = load_dataset('openai/gsm8k', 'main')[split]
|
||||
|
||||
def process_example(x):
|
||||
# Extract ground truth answer
|
||||
answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None
|
||||
|
||||
return {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': answer
|
||||
}
|
||||
|
||||
return data.map(process_example)
|
||||
|
||||
# ==================== HELPER FUNCTIONS ====================
|
||||
|
||||
def extract_xml_tag(text: str, tag: str) -> str:
|
||||
"""Extract content between XML tags."""
|
||||
pattern = f'<{tag}>(.*?)</{tag}>'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
return match.group(1).strip() if match else ""
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
"""Extract the final answer from structured output."""
|
||||
return extract_xml_tag(text, 'answer')
|
||||
|
||||
# ==================== REWARD FUNCTIONS ====================
|
||||
|
||||
def correctness_reward_func(prompts, completions, answer, **kwargs):
|
||||
"""
|
||||
Reward correct answers.
|
||||
Weight: 2.0 (highest priority)
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Reward proper XML format.
|
||||
Weight: 0.5
|
||||
"""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]
|
||||
|
||||
def incremental_format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Incremental reward for partial format compliance.
|
||||
Weight: up to 0.5
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.125
|
||||
if '</reasoning>' in r:
|
||||
score += 0.125
|
||||
if '<answer>' in r:
|
||||
score += 0.125
|
||||
if '</answer>' in r:
|
||||
score += 0.125
|
||||
|
||||
# Penalize extra content after closing tag
|
||||
if '</answer>' in r:
|
||||
extra = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra) * 0.001
|
||||
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== MODEL SETUP ====================
|
||||
|
||||
def setup_model_and_tokenizer():
|
||||
"""Load model and tokenizer with optimizations."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_NAME,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def get_peft_config():
|
||||
"""LoRA configuration for parameter-efficient training."""
|
||||
return LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# ==================== TRAINING ====================
|
||||
|
||||
def main():
|
||||
"""Main training function."""
|
||||
|
||||
# Load data
|
||||
print("Loading dataset...")
|
||||
dataset = get_dataset()
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Setup model
|
||||
print("Loading model...")
|
||||
model, tokenizer = setup_model_and_tokenizer()
|
||||
|
||||
# Training configuration
|
||||
training_args = GRPOConfig(
|
||||
output_dir=OUTPUT_DIR,
|
||||
run_name="grpo-training",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
|
||||
# GRPO specific
|
||||
num_generations=8,
|
||||
max_prompt_length=MAX_PROMPT_LENGTH,
|
||||
max_completion_length=MAX_COMPLETION_LENGTH,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
|
||||
# Optimization
|
||||
bf16=True,
|
||||
optim="adamw_8bit",
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Change to "none" to disable logging
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward_func,
|
||||
format_reward_func,
|
||||
correctness_reward_func,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=get_peft_config(),
|
||||
)
|
||||
|
||||
# Train
|
||||
print("Starting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print(f"Saving model to {OUTPUT_DIR}/final")
|
||||
trainer.save_model(f"{OUTPUT_DIR}/final")
|
||||
|
||||
print("Training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
575
skills/mlops/guidance/SKILL.md
Normal file
575
skills/mlops/guidance/SKILL.md
Normal file
|
|
@ -0,0 +1,575 @@
|
|||
---
|
||||
name: guidance
|
||||
description: Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [guidance, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Prompt Engineering, Guidance, Constrained Generation, Structured Output, JSON Validation, Grammar, Microsoft Research, Format Enforcement, Multi-Step Workflows]
|
||||
|
||||
---
|
||||
|
||||
# Guidance: Constrained LLM Generation
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use Guidance when you need to:
|
||||
- **Control LLM output syntax** with regex or grammars
|
||||
- **Guarantee valid JSON/XML/code** generation
|
||||
- **Reduce latency** vs traditional prompting approaches
|
||||
- **Enforce structured formats** (dates, emails, IDs, etc.)
|
||||
- **Build multi-step workflows** with Pythonic control flow
|
||||
- **Prevent invalid outputs** through grammatical constraints
|
||||
|
||||
**GitHub Stars**: 18,000+ | **From**: Microsoft Research
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Base installation
|
||||
pip install guidance
|
||||
|
||||
# With specific backends
|
||||
pip install guidance[transformers] # Hugging Face models
|
||||
pip install guidance[llama_cpp] # llama.cpp models
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Structured Generation
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
# Load model (supports OpenAI, Transformers, llama.cpp)
|
||||
lm = models.OpenAI("gpt-4")
|
||||
|
||||
# Generate with constraints
|
||||
result = lm + "The capital of France is " + gen("capital", max_tokens=5)
|
||||
|
||||
print(result["capital"]) # "Paris"
|
||||
```
|
||||
|
||||
### With Anthropic Claude
|
||||
|
||||
```python
|
||||
from guidance import models, gen, system, user, assistant
|
||||
|
||||
# Configure Claude
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use context managers for chat format
|
||||
with system():
|
||||
lm += "You are a helpful assistant."
|
||||
|
||||
with user():
|
||||
lm += "What is the capital of France?"
|
||||
|
||||
with assistant():
|
||||
lm += gen(max_tokens=20)
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Context Managers
|
||||
|
||||
Guidance uses Pythonic context managers for chat-style interactions.
|
||||
|
||||
```python
|
||||
from guidance import system, user, assistant, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# System message
|
||||
with system():
|
||||
lm += "You are a JSON generation expert."
|
||||
|
||||
# User message
|
||||
with user():
|
||||
lm += "Generate a person object with name and age."
|
||||
|
||||
# Assistant response
|
||||
with assistant():
|
||||
lm += gen("response", max_tokens=100)
|
||||
|
||||
print(lm["response"])
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Natural chat flow
|
||||
- Clear role separation
|
||||
- Easy to read and maintain
|
||||
|
||||
### 2. Constrained Generation
|
||||
|
||||
Guidance ensures outputs match specified patterns using regex or grammars.
|
||||
|
||||
#### Regex Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Constrain to valid email format
|
||||
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
|
||||
# Constrain to date format (YYYY-MM-DD)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
|
||||
|
||||
# Constrain to phone number
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
|
||||
|
||||
print(lm["email"]) # Guaranteed valid email
|
||||
print(lm["date"]) # Guaranteed YYYY-MM-DD format
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
- Regex converted to grammar at token level
|
||||
- Invalid tokens filtered during generation
|
||||
- Model can only produce matching outputs
|
||||
|
||||
#### Selection Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Constrain to specific choices
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
|
||||
# Multiple-choice selection
|
||||
lm += "Best answer: " + select(
|
||||
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
|
||||
name="answer"
|
||||
)
|
||||
|
||||
print(lm["sentiment"]) # One of: positive, negative, neutral
|
||||
print(lm["answer"]) # One of: A, B, C, or D
|
||||
```
|
||||
|
||||
### 3. Token Healing
|
||||
|
||||
Guidance automatically "heals" token boundaries between prompt and generation.
|
||||
|
||||
**Problem:** Tokenization creates unnatural boundaries.
|
||||
|
||||
```python
|
||||
# Without token healing
|
||||
prompt = "The capital of France is "
|
||||
# Last token: " is "
|
||||
# First generated token might be " Par" (with leading space)
|
||||
# Result: "The capital of France is Paris" (double space!)
|
||||
```
|
||||
|
||||
**Solution:** Guidance backs up one token and regenerates.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Token healing enabled by default
|
||||
lm += "The capital of France is " + gen("capital", max_tokens=5)
|
||||
# Result: "The capital of France is Paris" (correct spacing)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Natural text boundaries
|
||||
- No awkward spacing issues
|
||||
- Better model performance (sees natural token sequences)
|
||||
|
||||
### 4. Grammar-Based Generation
|
||||
|
||||
Define complex structures using context-free grammars.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# JSON grammar (simplified)
|
||||
json_grammar = """
|
||||
{
|
||||
"name": <gen name regex="[A-Za-z ]+" max_tokens=20>,
|
||||
"age": <gen age regex="[0-9]+" max_tokens=3>,
|
||||
"email": <gen email regex="[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}" max_tokens=50>
|
||||
}
|
||||
"""
|
||||
|
||||
# Generate valid JSON
|
||||
lm += gen("person", grammar=json_grammar)
|
||||
|
||||
print(lm["person"]) # Guaranteed valid JSON structure
|
||||
```
|
||||
|
||||
**Use cases:**
|
||||
- Complex structured outputs
|
||||
- Nested data structures
|
||||
- Programming language syntax
|
||||
- Domain-specific languages
|
||||
|
||||
### 5. Guidance Functions
|
||||
|
||||
Create reusable generation patterns with the `@guidance` decorator.
|
||||
|
||||
```python
|
||||
from guidance import guidance, gen, models
|
||||
|
||||
@guidance
|
||||
def generate_person(lm):
|
||||
"""Generate a person with name and age."""
|
||||
lm += "Name: " + gen("name", max_tokens=20, stop="\n")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+", max_tokens=3)
|
||||
return lm
|
||||
|
||||
# Use the function
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_person(lm)
|
||||
|
||||
print(lm["name"])
|
||||
print(lm["age"])
|
||||
```
|
||||
|
||||
**Stateful Functions:**
|
||||
|
||||
```python
|
||||
@guidance(stateless=False)
|
||||
def react_agent(lm, question, tools, max_rounds=5):
|
||||
"""ReAct agent with tool use."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for i in range(max_rounds):
|
||||
# Thought
|
||||
lm += f"Thought {i+1}: " + gen("thought", stop="\n")
|
||||
|
||||
# Action
|
||||
lm += "\nAction: " + select(list(tools.keys()), name="action")
|
||||
|
||||
# Execute tool
|
||||
tool_result = tools[lm["action"]]()
|
||||
lm += f"\nObservation: {tool_result}\n\n"
|
||||
|
||||
# Check if done
|
||||
lm += "Done? " + select(["Yes", "No"], name="done")
|
||||
if lm["done"] == "Yes":
|
||||
break
|
||||
|
||||
# Final answer
|
||||
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
|
||||
return lm
|
||||
```
|
||||
|
||||
## Backend Configuration
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key" # Or set ANTHROPIC_API_KEY env var
|
||||
)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o-mini",
|
||||
api_key="your-api-key" # Or set OPENAI_API_KEY env var
|
||||
)
|
||||
```
|
||||
|
||||
### Local Models (Transformers)
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda" # Or "cpu"
|
||||
)
|
||||
```
|
||||
|
||||
### Local Models (llama.cpp)
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: JSON Generation
|
||||
|
||||
```python
|
||||
from guidance import models, gen, system, user, assistant
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You generate valid JSON."
|
||||
|
||||
with user():
|
||||
lm += "Generate a user profile with name, age, and email."
|
||||
|
||||
with assistant():
|
||||
lm += """{
|
||||
"name": """ + gen("name", regex=r'"[A-Za-z ]+"', max_tokens=30) + """,
|
||||
"age": """ + gen("age", regex=r"[0-9]+", max_tokens=3) + """,
|
||||
"email": """ + gen("email", regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"', max_tokens=50) + """
|
||||
}"""
|
||||
|
||||
print(lm) # Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Pattern 2: Classification
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
text = "This product is amazing! I love it."
|
||||
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]+", max_tokens=3) + "%"
|
||||
|
||||
print(f"Sentiment: {lm['sentiment']}")
|
||||
print(f"Confidence: {lm['confidence']}%")
|
||||
```
|
||||
|
||||
### Pattern 3: Multi-Step Reasoning
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def chain_of_thought(lm, question):
|
||||
"""Generate answer with step-by-step reasoning."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
# Generate multiple reasoning steps
|
||||
for i in range(3):
|
||||
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Final answer
|
||||
lm += "\nTherefore, the answer is: " + gen("answer", max_tokens=50)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = chain_of_thought(lm, "What is 15% of 200?")
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Pattern 4: ReAct Agent
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select, guidance
|
||||
|
||||
@guidance(stateless=False)
|
||||
def react_agent(lm, question):
|
||||
"""ReAct agent with tool use."""
|
||||
tools = {
|
||||
"calculator": lambda expr: eval(expr),
|
||||
"search": lambda query: f"Search results for: {query}",
|
||||
}
|
||||
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for round in range(5):
|
||||
# Thought
|
||||
lm += f"Thought: " + gen("thought", stop="\n") + "\n"
|
||||
|
||||
# Action selection
|
||||
lm += "Action: " + select(["calculator", "search", "answer"], name="action")
|
||||
|
||||
if lm["action"] == "answer":
|
||||
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
|
||||
break
|
||||
|
||||
# Action input
|
||||
lm += "\nAction Input: " + gen("action_input", stop="\n") + "\n"
|
||||
|
||||
# Execute tool
|
||||
if lm["action"] in tools:
|
||||
result = tools[lm["action"]](lm["action_input"])
|
||||
lm += f"Observation: {result}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = react_agent(lm, "What is 25 * 4 + 10?")
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Pattern 5: Data Extraction
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract structured entities from text."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
# Extract person
|
||||
lm += "Person: " + gen("person", stop="\n", max_tokens=30) + "\n"
|
||||
|
||||
# Extract organization
|
||||
lm += "Organization: " + gen("organization", stop="\n", max_tokens=30) + "\n"
|
||||
|
||||
# Extract date
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}", max_tokens=10) + "\n"
|
||||
|
||||
# Extract location
|
||||
lm += "Location: " + gen("location", stop="\n", max_tokens=30) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = "Tim Cook announced at Apple Park on 2024-09-15 in Cupertino."
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = extract_entities(lm, text)
|
||||
|
||||
print(f"Person: {lm['person']}")
|
||||
print(f"Organization: {lm['organization']}")
|
||||
print(f"Date: {lm['date']}")
|
||||
print(f"Location: {lm['location']}")
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Regex for Format Validation
|
||||
|
||||
```python
|
||||
# ✅ Good: Regex ensures valid format
|
||||
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
|
||||
# ❌ Bad: Free generation may produce invalid emails
|
||||
lm += "Email: " + gen("email", max_tokens=50)
|
||||
```
|
||||
|
||||
### 2. Use select() for Fixed Categories
|
||||
|
||||
```python
|
||||
# ✅ Good: Guaranteed valid category
|
||||
lm += "Status: " + select(["pending", "approved", "rejected"], name="status")
|
||||
|
||||
# ❌ Bad: May generate typos or invalid values
|
||||
lm += "Status: " + gen("status", max_tokens=20)
|
||||
```
|
||||
|
||||
### 3. Leverage Token Healing
|
||||
|
||||
```python
|
||||
# Token healing is enabled by default
|
||||
# No special action needed - just concatenate naturally
|
||||
lm += "The capital is " + gen("capital") # Automatic healing
|
||||
```
|
||||
|
||||
### 4. Use stop Sequences
|
||||
|
||||
```python
|
||||
# ✅ Good: Stop at newline for single-line outputs
|
||||
lm += "Name: " + gen("name", stop="\n")
|
||||
|
||||
# ❌ Bad: May generate multiple lines
|
||||
lm += "Name: " + gen("name", max_tokens=50)
|
||||
```
|
||||
|
||||
### 5. Create Reusable Functions
|
||||
|
||||
```python
|
||||
# ✅ Good: Reusable pattern
|
||||
@guidance
|
||||
def generate_person(lm):
|
||||
lm += "Name: " + gen("name", stop="\n")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
|
||||
return lm
|
||||
|
||||
# Use multiple times
|
||||
lm = generate_person(lm)
|
||||
lm += "\n\n"
|
||||
lm = generate_person(lm)
|
||||
```
|
||||
|
||||
### 6. Balance Constraints
|
||||
|
||||
```python
|
||||
# ✅ Good: Reasonable constraints
|
||||
lm += gen("name", regex=r"[A-Za-z ]+", max_tokens=30)
|
||||
|
||||
# ❌ Too strict: May fail or be very slow
|
||||
lm += gen("name", regex=r"^(John|Jane)$", max_tokens=10)
|
||||
```
|
||||
|
||||
## Comparison to Alternatives
|
||||
|
||||
| Feature | Guidance | Instructor | Outlines | LMQL |
|
||||
|---------|----------|------------|----------|------|
|
||||
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
|
||||
| Grammar Support | ✅ CFG | ❌ No | ✅ CFG | ✅ CFG |
|
||||
| Pydantic Validation | ❌ No | ✅ Yes | ✅ Yes | ❌ No |
|
||||
| Token Healing | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
|
||||
| Local Models | ✅ Yes | ⚠️ Limited | ✅ Yes | ✅ Yes |
|
||||
| API Models | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
|
||||
| Pythonic Syntax | ✅ Yes | ✅ Yes | ✅ Yes | ❌ SQL-like |
|
||||
| Learning Curve | Low | Low | Medium | High |
|
||||
|
||||
**When to choose Guidance:**
|
||||
- Need regex/grammar constraints
|
||||
- Want token healing
|
||||
- Building complex workflows with control flow
|
||||
- Using local models (Transformers, llama.cpp)
|
||||
- Prefer Pythonic syntax
|
||||
|
||||
**When to choose alternatives:**
|
||||
- Instructor: Need Pydantic validation with automatic retrying
|
||||
- Outlines: Need JSON schema validation
|
||||
- LMQL: Prefer declarative query syntax
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
**Latency Reduction:**
|
||||
- 30-50% faster than traditional prompting for constrained outputs
|
||||
- Token healing reduces unnecessary regeneration
|
||||
- Grammar constraints prevent invalid token generation
|
||||
|
||||
**Memory Usage:**
|
||||
- Minimal overhead vs unconstrained generation
|
||||
- Grammar compilation cached after first use
|
||||
- Efficient token filtering at inference time
|
||||
|
||||
**Token Efficiency:**
|
||||
- Prevents wasted tokens on invalid outputs
|
||||
- No need for retry loops
|
||||
- Direct path to valid outputs
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://guidance.readthedocs.io
|
||||
- **GitHub**: https://github.com/guidance-ai/guidance (18k+ stars)
|
||||
- **Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
|
||||
- **Discord**: Community support available
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/constraints.md` - Comprehensive regex and grammar patterns
|
||||
- `references/backends.md` - Backend-specific configuration
|
||||
- `references/examples.md` - Production-ready examples
|
||||
|
||||
|
||||
554
skills/mlops/guidance/references/backends.md
Normal file
554
skills/mlops/guidance/references/backends.md
Normal file
|
|
@ -0,0 +1,554 @@
|
|||
# Backend Configuration Guide
|
||||
|
||||
Complete guide to configuring Guidance with different LLM backends.
|
||||
|
||||
## Table of Contents
|
||||
- API-Based Models (Anthropic, OpenAI)
|
||||
- Local Models (Transformers, llama.cpp)
|
||||
- Backend Comparison
|
||||
- Performance Tuning
|
||||
- Advanced Configuration
|
||||
|
||||
## API-Based Models
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
# Using environment variable
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
# Reads ANTHROPIC_API_KEY from environment
|
||||
|
||||
# Explicit API key
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key-here"
|
||||
)
|
||||
```
|
||||
|
||||
#### Available Models
|
||||
|
||||
```python
|
||||
# Claude 3.5 Sonnet (Latest, recommended)
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Claude 3.7 Sonnet (Fast, cost-effective)
|
||||
lm = models.Anthropic("claude-sonnet-3.7-20250219")
|
||||
|
||||
# Claude 3 Opus (Most capable)
|
||||
lm = models.Anthropic("claude-3-opus-20240229")
|
||||
|
||||
# Claude 3.5 Haiku (Fastest, cheapest)
|
||||
lm = models.Anthropic("claude-3-5-haiku-20241022")
|
||||
```
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
```python
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key",
|
||||
max_tokens=4096, # Max tokens to generate
|
||||
temperature=0.7, # Sampling temperature (0-1)
|
||||
top_p=0.9, # Nucleus sampling
|
||||
timeout=30, # Request timeout (seconds)
|
||||
max_retries=3 # Retry failed requests
|
||||
)
|
||||
```
|
||||
|
||||
#### With Context Managers
|
||||
|
||||
```python
|
||||
from guidance import models, system, user, assistant, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You are a helpful assistant."
|
||||
|
||||
with user():
|
||||
lm += "What is the capital of France?"
|
||||
|
||||
with assistant():
|
||||
lm += gen(max_tokens=50)
|
||||
|
||||
print(lm)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
# Using environment variable
|
||||
lm = models.OpenAI("gpt-4o")
|
||||
# Reads OPENAI_API_KEY from environment
|
||||
|
||||
# Explicit API key
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o",
|
||||
api_key="your-api-key-here"
|
||||
)
|
||||
```
|
||||
|
||||
#### Available Models
|
||||
|
||||
```python
|
||||
# GPT-4o (Latest, multimodal)
|
||||
lm = models.OpenAI("gpt-4o")
|
||||
|
||||
# GPT-4o Mini (Fast, cost-effective)
|
||||
lm = models.OpenAI("gpt-4o-mini")
|
||||
|
||||
# GPT-4 Turbo
|
||||
lm = models.OpenAI("gpt-4-turbo")
|
||||
|
||||
# GPT-3.5 Turbo (Cheapest)
|
||||
lm = models.OpenAI("gpt-3.5-turbo")
|
||||
```
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
```python
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o-mini",
|
||||
api_key="your-api-key",
|
||||
max_tokens=2048,
|
||||
temperature=0.7,
|
||||
top_p=1.0,
|
||||
frequency_penalty=0.0,
|
||||
presence_penalty=0.0,
|
||||
timeout=30
|
||||
)
|
||||
```
|
||||
|
||||
#### Chat Format
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.OpenAI("gpt-4o-mini")
|
||||
|
||||
# OpenAI uses chat format
|
||||
lm += [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is 2+2?"}
|
||||
]
|
||||
|
||||
# Generate response
|
||||
lm += gen(max_tokens=50)
|
||||
```
|
||||
|
||||
### Azure OpenAI
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
lm = models.AzureOpenAI(
|
||||
model="gpt-4o",
|
||||
azure_endpoint="https://your-resource.openai.azure.com/",
|
||||
api_key="your-azure-api-key",
|
||||
api_version="2024-02-15-preview",
|
||||
deployment_name="your-deployment-name"
|
||||
)
|
||||
```
|
||||
|
||||
## Local Models
|
||||
|
||||
### Transformers (Hugging Face)
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Load model from Hugging Face
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
```
|
||||
|
||||
#### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use GPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Use specific GPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda:0" # GPU 0
|
||||
)
|
||||
|
||||
# Use CPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cpu"
|
||||
)
|
||||
```
|
||||
|
||||
#### Advanced Configuration
|
||||
|
||||
```python
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda",
|
||||
torch_dtype="float16", # Use FP16 (faster, less memory)
|
||||
load_in_8bit=True, # 8-bit quantization
|
||||
max_memory={0: "20GB"}, # GPU memory limit
|
||||
offload_folder="./offload" # Offload to disk if needed
|
||||
)
|
||||
```
|
||||
|
||||
#### Popular Models
|
||||
|
||||
```python
|
||||
# Phi-4 (Microsoft)
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
lm = Transformers("microsoft/Phi-3-medium-4k-instruct")
|
||||
|
||||
# Llama 3 (Meta)
|
||||
lm = Transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
lm = Transformers("meta-llama/Llama-3.1-70B-Instruct")
|
||||
|
||||
# Mistral (Mistral AI)
|
||||
lm = Transformers("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
lm = Transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
|
||||
# Qwen (Alibaba)
|
||||
lm = Transformers("Qwen/Qwen2.5-7B-Instruct")
|
||||
|
||||
# Gemma (Google)
|
||||
lm = Transformers("google/gemma-2-9b-it")
|
||||
```
|
||||
|
||||
#### Generation Configuration
|
||||
|
||||
```python
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Configure generation
|
||||
from guidance import gen
|
||||
|
||||
result = lm + gen(
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=50,
|
||||
repetition_penalty=1.1
|
||||
)
|
||||
```
|
||||
|
||||
### llama.cpp
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
# Load GGUF model
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096 # Context window
|
||||
)
|
||||
```
|
||||
|
||||
#### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use GPU acceleration
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35, # Offload 35 layers to GPU
|
||||
n_threads=8 # CPU threads for remaining layers
|
||||
)
|
||||
|
||||
# Full GPU offload
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=-1 # Offload all layers
|
||||
)
|
||||
```
|
||||
|
||||
#### Advanced Configuration
|
||||
|
||||
```python
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/llama-3.1-8b-instruct.Q4_K_M.gguf",
|
||||
n_ctx=8192, # Context window (tokens)
|
||||
n_gpu_layers=35, # GPU layers
|
||||
n_threads=8, # CPU threads
|
||||
n_batch=512, # Batch size for prompt processing
|
||||
use_mmap=True, # Memory-map the model file
|
||||
use_mlock=False, # Lock model in RAM
|
||||
seed=42, # Random seed
|
||||
verbose=False # Suppress verbose output
|
||||
)
|
||||
```
|
||||
|
||||
#### Quantized Models
|
||||
|
||||
```python
|
||||
# Q4_K_M (4-bit, recommended for most cases)
|
||||
lm = LlamaCpp("/path/to/model.Q4_K_M.gguf")
|
||||
|
||||
# Q5_K_M (5-bit, better quality)
|
||||
lm = LlamaCpp("/path/to/model.Q5_K_M.gguf")
|
||||
|
||||
# Q8_0 (8-bit, high quality)
|
||||
lm = LlamaCpp("/path/to/model.Q8_0.gguf")
|
||||
|
||||
# F16 (16-bit float, highest quality)
|
||||
lm = LlamaCpp("/path/to/model.F16.gguf")
|
||||
```
|
||||
|
||||
#### Popular GGUF Models
|
||||
|
||||
```python
|
||||
# Llama 3.1
|
||||
lm = LlamaCpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
|
||||
|
||||
# Mistral
|
||||
lm = LlamaCpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
|
||||
|
||||
# Phi-4
|
||||
lm = LlamaCpp("phi-4-mini-instruct.Q4_K_M.gguf")
|
||||
```
|
||||
|
||||
## Backend Comparison
|
||||
|
||||
### Feature Matrix
|
||||
|
||||
| Feature | Anthropic | OpenAI | Transformers | llama.cpp |
|
||||
|---------|-----------|--------|--------------|-----------|
|
||||
| Constrained Generation | ✅ Full | ✅ Full | ✅ Full | ✅ Full |
|
||||
| Token Healing | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
|
||||
| Streaming | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
|
||||
| GPU Support | N/A | N/A | ✅ Yes | ✅ Yes |
|
||||
| Quantization | N/A | N/A | ✅ Yes | ✅ Yes |
|
||||
| Cost | $$$ | $$$ | Free | Free |
|
||||
| Latency | Low | Low | Medium | Low |
|
||||
| Setup Difficulty | Easy | Easy | Medium | Medium |
|
||||
|
||||
### Performance Characteristics
|
||||
|
||||
**Anthropic Claude:**
|
||||
- **Latency**: 200-500ms (API call)
|
||||
- **Throughput**: Limited by API rate limits
|
||||
- **Cost**: $3-15 per 1M input tokens
|
||||
- **Best for**: Production systems, high-quality outputs
|
||||
|
||||
**OpenAI:**
|
||||
- **Latency**: 200-400ms (API call)
|
||||
- **Throughput**: Limited by API rate limits
|
||||
- **Cost**: $0.15-30 per 1M input tokens
|
||||
- **Best for**: Cost-sensitive production, gpt-4o-mini
|
||||
|
||||
**Transformers:**
|
||||
- **Latency**: 50-200ms (local inference)
|
||||
- **Throughput**: GPU-dependent (10-100 tokens/sec)
|
||||
- **Cost**: Hardware cost only
|
||||
- **Best for**: Privacy-sensitive, high-volume, experimentation
|
||||
|
||||
**llama.cpp:**
|
||||
- **Latency**: 30-150ms (local inference)
|
||||
- **Throughput**: Hardware-dependent (20-150 tokens/sec)
|
||||
- **Cost**: Hardware cost only
|
||||
- **Best for**: Edge deployment, Apple Silicon, CPU inference
|
||||
|
||||
### Memory Requirements
|
||||
|
||||
**Transformers (FP16):**
|
||||
- 7B model: ~14GB GPU VRAM
|
||||
- 13B model: ~26GB GPU VRAM
|
||||
- 70B model: ~140GB GPU VRAM (multi-GPU)
|
||||
|
||||
**llama.cpp (Q4_K_M):**
|
||||
- 7B model: ~4.5GB RAM
|
||||
- 13B model: ~8GB RAM
|
||||
- 70B model: ~40GB RAM
|
||||
|
||||
**Optimization Tips:**
|
||||
- Use quantized models (Q4_K_M) for lower memory
|
||||
- Use GPU offloading for faster inference
|
||||
- Use CPU inference for smaller models (<7B)
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### API Models (Anthropic, OpenAI)
|
||||
|
||||
#### Reduce Latency
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use lower max_tokens (faster response)
|
||||
lm += gen(max_tokens=100) # Instead of 1000
|
||||
|
||||
# Use streaming (perceived latency reduction)
|
||||
for chunk in lm.stream(gen(max_tokens=500)):
|
||||
print(chunk, end="", flush=True)
|
||||
```
|
||||
|
||||
#### Reduce Cost
|
||||
|
||||
```python
|
||||
# Use cheaper models
|
||||
lm = models.Anthropic("claude-3-5-haiku-20241022") # vs Sonnet
|
||||
lm = models.OpenAI("gpt-4o-mini") # vs gpt-4o
|
||||
|
||||
# Reduce context size
|
||||
# - Keep prompts concise
|
||||
# - Avoid large few-shot examples
|
||||
# - Use max_tokens limits
|
||||
```
|
||||
|
||||
### Local Models (Transformers, llama.cpp)
|
||||
|
||||
#### Optimize GPU Usage
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Use FP16 for 2x speedup
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use 8-bit quantization for 4x memory reduction
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
load_in_8bit=True
|
||||
)
|
||||
|
||||
# Use flash attention (requires flash-attn package)
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
use_flash_attention_2=True
|
||||
)
|
||||
```
|
||||
|
||||
#### Optimize llama.cpp
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
# Maximize GPU layers
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1 # All layers on GPU
|
||||
)
|
||||
|
||||
# Optimize batch size
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_batch=512, # Larger batch = faster prompt processing
|
||||
n_gpu_layers=-1
|
||||
)
|
||||
|
||||
# Use Metal (Apple Silicon)
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1, # Use Metal GPU acceleration
|
||||
use_mmap=True
|
||||
)
|
||||
```
|
||||
|
||||
#### Batch Processing
|
||||
|
||||
```python
|
||||
# Process multiple requests efficiently
|
||||
requests = [
|
||||
"What is 2+2?",
|
||||
"What is the capital of France?",
|
||||
"What is photosynthesis?"
|
||||
]
|
||||
|
||||
# Bad: Sequential processing
|
||||
for req in requests:
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
lm += req + gen(max_tokens=50)
|
||||
|
||||
# Good: Reuse loaded model
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
for req in requests:
|
||||
lm += req + gen(max_tokens=50)
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Custom Model Configurations
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Load custom model
|
||||
tokenizer = AutoTokenizer.from_pretrained("your-model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"your-model",
|
||||
device_map="auto",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use with Guidance
|
||||
lm = Transformers(model=model, tokenizer=tokenizer)
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# API keys
|
||||
export ANTHROPIC_API_KEY="sk-ant-..."
|
||||
export OPENAI_API_KEY="sk-..."
|
||||
|
||||
# Transformers cache
|
||||
export HF_HOME="/path/to/cache"
|
||||
export TRANSFORMERS_CACHE="/path/to/cache"
|
||||
|
||||
# GPU selection
|
||||
export CUDA_VISIBLE_DEVICES=0,1 # Use GPU 0 and 1
|
||||
```
|
||||
|
||||
### Debugging
|
||||
|
||||
```python
|
||||
# Enable verbose logging
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# Check backend info
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
print(f"Model: {lm.model_name}")
|
||||
print(f"Backend: {lm.backend}")
|
||||
|
||||
# Check GPU usage (Transformers)
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct", device="cuda")
|
||||
print(f"Device: {lm.device}")
|
||||
print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Anthropic Docs**: https://docs.anthropic.com
|
||||
- **OpenAI Docs**: https://platform.openai.com/docs
|
||||
- **Hugging Face Models**: https://huggingface.co/models
|
||||
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
|
||||
- **GGUF Models**: https://huggingface.co/models?library=gguf
|
||||
674
skills/mlops/guidance/references/constraints.md
Normal file
674
skills/mlops/guidance/references/constraints.md
Normal file
|
|
@ -0,0 +1,674 @@
|
|||
# Comprehensive Constraint Patterns
|
||||
|
||||
Guide to regex constraints, grammar-based generation, and token healing in Guidance.
|
||||
|
||||
## Table of Contents
|
||||
- Regex Constraints
|
||||
- Grammar-Based Generation
|
||||
- Token Healing
|
||||
- Selection Constraints
|
||||
- Complex Patterns
|
||||
- Performance Optimization
|
||||
|
||||
## Regex Constraints
|
||||
|
||||
### Basic Patterns
|
||||
|
||||
#### Numeric Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Integer (positive)
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+")
|
||||
|
||||
# Integer (with negatives)
|
||||
lm += "Temperature: " + gen("temp", regex=r"-?[0-9]+")
|
||||
|
||||
# Float (positive)
|
||||
lm += "Price: $" + gen("price", regex=r"[0-9]+\.[0-9]{2}")
|
||||
|
||||
# Float (with negatives and optional decimals)
|
||||
lm += "Value: " + gen("value", regex=r"-?[0-9]+(\.[0-9]+)?")
|
||||
|
||||
# Percentage (0-100)
|
||||
lm += "Progress: " + gen("progress", regex=r"(100|[0-9]{1,2})")
|
||||
|
||||
# Range (1-5 stars)
|
||||
lm += "Rating: " + gen("rating", regex=r"[1-5]") + " stars"
|
||||
```
|
||||
|
||||
#### Text Constraints
|
||||
|
||||
```python
|
||||
# Alphabetic only
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z]+")
|
||||
|
||||
# Alphabetic with spaces
|
||||
lm += "Full Name: " + gen("full_name", regex=r"[A-Za-z ]+")
|
||||
|
||||
# Alphanumeric
|
||||
lm += "Username: " + gen("username", regex=r"[A-Za-z0-9_]+")
|
||||
|
||||
# Capitalized words
|
||||
lm += "Title: " + gen("title", regex=r"[A-Z][a-z]+( [A-Z][a-z]+)*")
|
||||
|
||||
# Lowercase only
|
||||
lm += "Code: " + gen("code", regex=r"[a-z0-9-]+")
|
||||
|
||||
# Specific length
|
||||
lm += "ID: " + gen("id", regex=r"[A-Z]{3}-[0-9]{6}") # e.g., "ABC-123456"
|
||||
```
|
||||
|
||||
#### Date and Time Constraints
|
||||
|
||||
```python
|
||||
# Date (YYYY-MM-DD)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
|
||||
|
||||
# Date (MM/DD/YYYY)
|
||||
lm += "Date: " + gen("date_us", regex=r"\d{2}/\d{2}/\d{4}")
|
||||
|
||||
# Time (HH:MM)
|
||||
lm += "Time: " + gen("time", regex=r"\d{2}:\d{2}")
|
||||
|
||||
# Time (HH:MM:SS)
|
||||
lm += "Time: " + gen("time_full", regex=r"\d{2}:\d{2}:\d{2}")
|
||||
|
||||
# ISO 8601 datetime
|
||||
lm += "Timestamp: " + gen(
|
||||
"timestamp",
|
||||
regex=r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z"
|
||||
)
|
||||
|
||||
# Year (YYYY)
|
||||
lm += "Year: " + gen("year", regex=r"(19|20)\d{2}")
|
||||
|
||||
# Month name
|
||||
lm += "Month: " + gen(
|
||||
"month",
|
||||
regex=r"(January|February|March|April|May|June|July|August|September|October|November|December)"
|
||||
)
|
||||
```
|
||||
|
||||
#### Contact Information
|
||||
|
||||
```python
|
||||
# Email
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
|
||||
)
|
||||
|
||||
# Phone (US format)
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
|
||||
|
||||
# Phone (international format)
|
||||
lm += "Phone: " + gen("phone_intl", regex=r"\+[0-9]{1,3}-[0-9]{1,14}")
|
||||
|
||||
# ZIP code (US)
|
||||
lm += "ZIP: " + gen("zip", regex=r"\d{5}(-\d{4})?")
|
||||
|
||||
# Postal code (Canada)
|
||||
lm += "Postal: " + gen("postal", regex=r"[A-Z]\d[A-Z] \d[A-Z]\d")
|
||||
|
||||
# URL
|
||||
lm += "URL: " + gen(
|
||||
"url",
|
||||
regex=r"https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/[a-zA-Z0-9._~:/?#\[\]@!$&'()*+,;=-]*)?"
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Patterns
|
||||
|
||||
#### JSON Field Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# String field with quotes
|
||||
lm += '"name": ' + gen("name", regex=r'"[A-Za-z ]+"')
|
||||
|
||||
# Numeric field (no quotes)
|
||||
lm += '"age": ' + gen("age", regex=r"[0-9]+")
|
||||
|
||||
# Boolean field
|
||||
lm += '"active": ' + gen("active", regex=r"(true|false)")
|
||||
|
||||
# Null field
|
||||
lm += '"optional": ' + gen("optional", regex=r"(null|[0-9]+)")
|
||||
|
||||
# Array of strings
|
||||
lm += '"tags": [' + gen(
|
||||
"tags",
|
||||
regex=r'"[a-z]+"(, "[a-z]+")*'
|
||||
) + ']'
|
||||
|
||||
# Complete JSON object
|
||||
lm += """{
|
||||
"name": """ + gen("name", regex=r'"[A-Za-z ]+"') + """,
|
||||
"age": """ + gen("age", regex=r"[0-9]+") + """,
|
||||
"email": """ + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + """
|
||||
}"""
|
||||
```
|
||||
|
||||
#### Code Patterns
|
||||
|
||||
```python
|
||||
# Python variable name
|
||||
lm += "Variable: " + gen("var", regex=r"[a-z_][a-z0-9_]*")
|
||||
|
||||
# Python function name
|
||||
lm += "Function: " + gen("func", regex=r"[a-z_][a-z0-9_]*")
|
||||
|
||||
# Hex color code
|
||||
lm += "Color: #" + gen("color", regex=r"[0-9A-Fa-f]{6}")
|
||||
|
||||
# UUID
|
||||
lm += "UUID: " + gen(
|
||||
"uuid",
|
||||
regex=r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
)
|
||||
|
||||
# Git commit hash (short)
|
||||
lm += "Commit: " + gen("commit", regex=r"[0-9a-f]{7}")
|
||||
|
||||
# Semantic version
|
||||
lm += "Version: " + gen("version", regex=r"[0-9]+\.[0-9]+\.[0-9]+")
|
||||
|
||||
# IP address (IPv4)
|
||||
lm += "IP: " + gen(
|
||||
"ip",
|
||||
regex=r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
|
||||
)
|
||||
```
|
||||
|
||||
#### Domain-Specific Patterns
|
||||
|
||||
```python
|
||||
# Credit card number
|
||||
lm += "Card: " + gen("card", regex=r"\d{4}-\d{4}-\d{4}-\d{4}")
|
||||
|
||||
# Social Security Number (US)
|
||||
lm += "SSN: " + gen("ssn", regex=r"\d{3}-\d{2}-\d{4}")
|
||||
|
||||
# ISBN-13
|
||||
lm += "ISBN: " + gen("isbn", regex=r"978-\d{1,5}-\d{1,7}-\d{1,7}-\d")
|
||||
|
||||
# License plate (US)
|
||||
lm += "Plate: " + gen("plate", regex=r"[A-Z]{3}-\d{4}")
|
||||
|
||||
# Currency amount
|
||||
lm += "Amount: $" + gen("amount", regex=r"[0-9]{1,3}(,[0-9]{3})*\.[0-9]{2}")
|
||||
|
||||
# Percentage with decimal
|
||||
lm += "Rate: " + gen("rate", regex=r"[0-9]+\.[0-9]{1,2}%")
|
||||
```
|
||||
|
||||
## Grammar-Based Generation
|
||||
|
||||
### JSON Grammar
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def json_object(lm):
|
||||
"""Generate valid JSON object."""
|
||||
lm += "{\n"
|
||||
|
||||
# Name field (required)
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
|
||||
# Age field (required)
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
|
||||
|
||||
# Email field (required)
|
||||
lm += ' "email": ' + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + ",\n"
|
||||
|
||||
# Active field (required, boolean)
|
||||
lm += ' "active": ' + gen("active", regex=r"(true|false)") + "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = json_object(lm)
|
||||
print(lm) # Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Nested JSON Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def nested_json(lm):
|
||||
"""Generate nested JSON structure."""
|
||||
lm += "{\n"
|
||||
|
||||
# User object
|
||||
lm += ' "user": {\n'
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Address object
|
||||
lm += ' "address": {\n'
|
||||
lm += ' "street": ' + gen("street", regex=r'"[A-Za-z0-9 ]+"') + ",\n"
|
||||
lm += ' "city": ' + gen("city", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "zip": ' + gen("zip", regex=r'"\d{5}"') + "\n"
|
||||
lm += " }\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
```
|
||||
|
||||
### Array Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def json_array(lm, count=3):
|
||||
"""Generate JSON array with fixed count."""
|
||||
lm += "[\n"
|
||||
|
||||
for i in range(count):
|
||||
lm += " {\n"
|
||||
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + "\n"
|
||||
lm += " }"
|
||||
if i < count - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "]"
|
||||
return lm
|
||||
```
|
||||
|
||||
### XML Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def xml_document(lm):
|
||||
"""Generate valid XML document."""
|
||||
lm += '<?xml version="1.0"?>\n'
|
||||
lm += "<person>\n"
|
||||
|
||||
# Name element
|
||||
lm += " <name>" + gen("name", regex=r"[A-Za-z ]+") + "</name>\n"
|
||||
|
||||
# Age element
|
||||
lm += " <age>" + gen("age", regex=r"[0-9]+") + "</age>\n"
|
||||
|
||||
# Email element
|
||||
lm += " <email>" + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
|
||||
) + "</email>\n"
|
||||
|
||||
lm += "</person>"
|
||||
return lm
|
||||
```
|
||||
|
||||
### CSV Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def csv_row(lm):
|
||||
"""Generate CSV row."""
|
||||
lm += gen("name", regex=r"[A-Za-z ]+") + ","
|
||||
lm += gen("age", regex=r"[0-9]+") + ","
|
||||
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def csv_document(lm, rows=5):
|
||||
"""Generate complete CSV."""
|
||||
# Header
|
||||
lm += "Name,Age,Email\n"
|
||||
|
||||
# Rows
|
||||
for i in range(rows):
|
||||
lm = csv_row(lm)
|
||||
if i < rows - 1:
|
||||
lm += "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Token Healing
|
||||
|
||||
### How Token Healing Works
|
||||
|
||||
**Problem:** Tokenization creates unnatural boundaries.
|
||||
|
||||
```python
|
||||
# Example without token healing
|
||||
prompt = "The capital of France is "
|
||||
# Tokenization: ["The", " capital", " of", " France", " is", " "]
|
||||
# Model sees last token: " "
|
||||
# First generated token might include leading space: " Paris"
|
||||
# Result: "The capital of France is Paris" (double space)
|
||||
```
|
||||
|
||||
**Solution:** Guidance backs up and regenerates the last token.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Token healing enabled by default
|
||||
lm += "The capital of France is " + gen("capital", max_tokens=5)
|
||||
|
||||
# Process:
|
||||
# 1. Back up to token before " is "
|
||||
# 2. Regenerate " is" + "capital" together
|
||||
# 3. Result: "The capital of France is Paris" (correct)
|
||||
```
|
||||
|
||||
### Token Healing Examples
|
||||
|
||||
#### Natural Continuations
|
||||
|
||||
```python
|
||||
# Before token healing
|
||||
lm += "The function name is get" + gen("rest")
|
||||
# Might generate: "The function name is get User" (space before User)
|
||||
|
||||
# With token healing
|
||||
lm += "The function name is get" + gen("rest")
|
||||
# Generates: "The function name is getUser" (correct camelCase)
|
||||
```
|
||||
|
||||
#### Code Generation
|
||||
|
||||
```python
|
||||
# Function name completion
|
||||
lm += "def calculate_" + gen("rest", stop="(")
|
||||
# Token healing ensures smooth connection: "calculate_total"
|
||||
|
||||
# Variable name completion
|
||||
lm += "my_" + gen("var_name", regex=r"[a-z_]+")
|
||||
# Token healing ensures: "my_variable_name" (not "my_ variable_name")
|
||||
```
|
||||
|
||||
#### Domain-Specific Terms
|
||||
|
||||
```python
|
||||
# Medical terms
|
||||
lm += "The patient has hyper" + gen("condition")
|
||||
# Token healing helps: "hypertension" (not "hyper tension")
|
||||
|
||||
# Technical terms
|
||||
lm += "Using micro" + gen("tech")
|
||||
# Token healing helps: "microservices" (not "micro services")
|
||||
```
|
||||
|
||||
### Disabling Token Healing
|
||||
|
||||
```python
|
||||
# Disable token healing if needed (rare)
|
||||
lm += gen("text", token_healing=False)
|
||||
```
|
||||
|
||||
## Selection Constraints
|
||||
|
||||
### Basic Selection
|
||||
|
||||
```python
|
||||
from guidance import models, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Simple selection
|
||||
lm += "Status: " + select(["active", "inactive", "pending"], name="status")
|
||||
|
||||
# Boolean selection
|
||||
lm += "Approved: " + select(["Yes", "No"], name="approved")
|
||||
|
||||
# Multiple choice
|
||||
lm += "Answer: " + select(
|
||||
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
|
||||
name="answer"
|
||||
)
|
||||
```
|
||||
|
||||
### Conditional Selection
|
||||
|
||||
```python
|
||||
from guidance import models, select, gen, guidance
|
||||
|
||||
@guidance
|
||||
def conditional_fields(lm):
|
||||
"""Generate fields conditionally based on type."""
|
||||
lm += "Type: " + select(["person", "company"], name="type")
|
||||
|
||||
if lm["type"] == "person":
|
||||
lm += "\nName: " + gen("name", regex=r"[A-Za-z ]+")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
|
||||
else:
|
||||
lm += "\nCompany Name: " + gen("company", regex=r"[A-Za-z ]+")
|
||||
lm += "\nEmployees: " + gen("employees", regex=r"[0-9]+")
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Repeated Selection
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def multiple_selections(lm):
|
||||
"""Select multiple items."""
|
||||
lm += "Select 3 colors:\n"
|
||||
|
||||
colors = ["red", "blue", "green", "yellow", "purple"]
|
||||
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + select(colors, name=f"color_{i}") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Complex Patterns
|
||||
|
||||
### Pattern 1: Structured Forms
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def user_form(lm):
|
||||
"""Generate structured user form."""
|
||||
lm += "=== User Registration ===\n\n"
|
||||
|
||||
# Name (alphabetic only)
|
||||
lm += "Full Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Age (numeric)
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
|
||||
|
||||
# Email (validated format)
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
# Phone (US format)
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") + "\n"
|
||||
|
||||
# Account type (selection)
|
||||
lm += "Account Type: " + select(
|
||||
["Standard", "Premium", "Enterprise"],
|
||||
name="account_type"
|
||||
) + "\n"
|
||||
|
||||
# Active status (boolean)
|
||||
lm += "Active: " + select(["Yes", "No"], name="active") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 2: Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract multiple entities with constraints."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
# Person name (alphabetic)
|
||||
lm += "Person: " + gen("person", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Organization (alphanumeric with spaces)
|
||||
lm += "Organization: " + gen(
|
||||
"organization",
|
||||
regex=r"[A-Za-z0-9 ]+",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
# Date (YYYY-MM-DD format)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") + "\n"
|
||||
|
||||
# Location (alphabetic with spaces)
|
||||
lm += "Location: " + gen("location", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Amount (currency)
|
||||
lm += "Amount: $" + gen("amount", regex=r"[0-9,]+\.[0-9]{2}") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 3: Code Generation
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_python_function(lm):
|
||||
"""Generate Python function with constraints."""
|
||||
# Function name (valid Python identifier)
|
||||
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
|
||||
|
||||
# Parameter name
|
||||
lm += gen("param", regex=r"[a-z_][a-z0-9_]*") + "):\n"
|
||||
|
||||
# Docstring
|
||||
lm += ' """' + gen("docstring", stop='"""', max_tokens=50) + '"""\n'
|
||||
|
||||
# Function body (constrained to valid Python)
|
||||
lm += " return " + gen("return_value", stop="\n") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 4: Hierarchical Data
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def org_chart(lm):
|
||||
"""Generate organizational chart."""
|
||||
lm += "Company: " + gen("company", regex=r"[A-Za-z ]+") + "\n\n"
|
||||
|
||||
# CEO
|
||||
lm += "CEO: " + gen("ceo", regex=r"[A-Za-z ]+") + "\n"
|
||||
|
||||
# Departments
|
||||
for dept in ["Engineering", "Sales", "Marketing"]:
|
||||
lm += f"\n{dept} Department:\n"
|
||||
lm += " Head: " + gen(f"{dept.lower()}_head", regex=r"[A-Za-z ]+") + "\n"
|
||||
lm += " Size: " + gen(f"{dept.lower()}_size", regex=r"[0-9]+") + " employees\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Best Practices
|
||||
|
||||
#### 1. Use Specific Patterns
|
||||
|
||||
```python
|
||||
# ✅ Good: Specific pattern
|
||||
lm += gen("age", regex=r"[0-9]{1,3}") # Fast
|
||||
|
||||
# ❌ Bad: Overly broad pattern
|
||||
lm += gen("age", regex=r"[0-9]+") # Slower
|
||||
```
|
||||
|
||||
#### 2. Limit Max Tokens
|
||||
|
||||
```python
|
||||
# ✅ Good: Reasonable limit
|
||||
lm += gen("name", max_tokens=30)
|
||||
|
||||
# ❌ Bad: No limit
|
||||
lm += gen("name") # May generate forever
|
||||
```
|
||||
|
||||
#### 3. Use stop Sequences
|
||||
|
||||
```python
|
||||
# ✅ Good: Stop at newline
|
||||
lm += gen("line", stop="\n")
|
||||
|
||||
# ❌ Bad: Rely on max_tokens
|
||||
lm += gen("line", max_tokens=100)
|
||||
```
|
||||
|
||||
#### 4. Cache Compiled Grammars
|
||||
|
||||
```python
|
||||
# Grammars are cached automatically after first use
|
||||
# No manual caching needed
|
||||
@guidance
|
||||
def reusable_pattern(lm):
|
||||
"""This grammar is compiled once and cached."""
|
||||
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
return lm
|
||||
|
||||
# First call: compiles grammar
|
||||
lm = reusable_pattern(lm)
|
||||
|
||||
# Subsequent calls: uses cached grammar (fast)
|
||||
lm = reusable_pattern(lm)
|
||||
```
|
||||
|
||||
#### 5. Avoid Overlapping Constraints
|
||||
|
||||
```python
|
||||
# ✅ Good: Clear constraints
|
||||
lm += gen("age", regex=r"[0-9]+", max_tokens=3)
|
||||
|
||||
# ❌ Bad: Conflicting constraints
|
||||
lm += gen("age", regex=r"[0-9]{2}", max_tokens=10) # max_tokens unnecessary
|
||||
```
|
||||
|
||||
### Performance Benchmarks
|
||||
|
||||
**Regex vs Free Generation:**
|
||||
- Simple regex (digits): ~1.2x slower than free gen
|
||||
- Complex regex (email): ~1.5x slower than free gen
|
||||
- Grammar-based: ~2x slower than free gen
|
||||
|
||||
**But:**
|
||||
- 100% valid outputs (vs ~70% with free gen + validation)
|
||||
- No retry loops needed
|
||||
- Overall faster end-to-end for structured outputs
|
||||
|
||||
**Optimization Tips:**
|
||||
- Use regex for critical fields only
|
||||
- Use `select()` for small fixed sets (fastest)
|
||||
- Use `stop` sequences when possible (faster than max_tokens)
|
||||
- Cache compiled grammars by reusing functions
|
||||
|
||||
## Resources
|
||||
|
||||
- **Token Healing Paper**: https://arxiv.org/abs/2306.17648
|
||||
- **Guidance Docs**: https://guidance.readthedocs.io
|
||||
- **GitHub**: https://github.com/guidance-ai/guidance
|
||||
767
skills/mlops/guidance/references/examples.md
Normal file
767
skills/mlops/guidance/references/examples.md
Normal file
|
|
@ -0,0 +1,767 @@
|
|||
# Production-Ready Examples
|
||||
|
||||
Real-world examples of using Guidance for structured generation, agents, and workflows.
|
||||
|
||||
## Table of Contents
|
||||
- JSON Generation
|
||||
- Data Extraction
|
||||
- Classification Systems
|
||||
- Agent Systems
|
||||
- Multi-Step Workflows
|
||||
- Code Generation
|
||||
- Production Tips
|
||||
|
||||
## JSON Generation
|
||||
|
||||
### Basic JSON
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def generate_user(lm):
|
||||
"""Generate valid user JSON."""
|
||||
lm += "{\n"
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "email": ' + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + "\n"
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
# Use it
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm += "Generate a user profile:\n"
|
||||
lm = generate_user(lm)
|
||||
|
||||
print(lm)
|
||||
# Output: Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Nested JSON
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_order(lm):
|
||||
"""Generate nested order JSON."""
|
||||
lm += "{\n"
|
||||
|
||||
# Customer info
|
||||
lm += ' "customer": {\n'
|
||||
lm += ' "name": ' + gen("customer_name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "email": ' + gen(
|
||||
"customer_email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Order details
|
||||
lm += ' "order": {\n'
|
||||
lm += ' "id": ' + gen("order_id", regex=r'"ORD-[0-9]{6}"') + ",\n"
|
||||
lm += ' "date": ' + gen("order_date", regex=r'"\d{4}-\d{2}-\d{2}"') + ",\n"
|
||||
lm += ' "total": ' + gen("order_total", regex=r"[0-9]+\.[0-9]{2}") + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Status
|
||||
lm += ' "status": ' + gen(
|
||||
"status",
|
||||
regex=r'"(pending|processing|shipped|delivered)"'
|
||||
) + "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_order(lm)
|
||||
```
|
||||
|
||||
### JSON Array
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_user_list(lm, count=3):
|
||||
"""Generate JSON array of users."""
|
||||
lm += "[\n"
|
||||
|
||||
for i in range(count):
|
||||
lm += " {\n"
|
||||
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "active": ' + gen(f"active_{i}", regex=r"(true|false)") + "\n"
|
||||
lm += " }"
|
||||
if i < count - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "]"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_user_list(lm, count=5)
|
||||
```
|
||||
|
||||
### Dynamic JSON Schema
|
||||
|
||||
```python
|
||||
import json
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def json_from_schema(lm, schema):
|
||||
"""Generate JSON matching a schema."""
|
||||
lm += "{\n"
|
||||
|
||||
fields = list(schema["properties"].items())
|
||||
for i, (field_name, field_schema) in enumerate(fields):
|
||||
lm += f' "{field_name}": '
|
||||
|
||||
# Handle different types
|
||||
if field_schema["type"] == "string":
|
||||
if "pattern" in field_schema:
|
||||
lm += gen(field_name, regex=f'"{field_schema["pattern"]}"')
|
||||
else:
|
||||
lm += gen(field_name, regex=r'"[^"]+"')
|
||||
elif field_schema["type"] == "number":
|
||||
lm += gen(field_name, regex=r"[0-9]+(\.[0-9]+)?")
|
||||
elif field_schema["type"] == "integer":
|
||||
lm += gen(field_name, regex=r"[0-9]+")
|
||||
elif field_schema["type"] == "boolean":
|
||||
lm += gen(field_name, regex=r"(true|false)")
|
||||
|
||||
if i < len(fields) - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
# Define schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"score": {"type": "number"},
|
||||
"active": {"type": "boolean"}
|
||||
}
|
||||
}
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = json_from_schema(lm, schema)
|
||||
```
|
||||
|
||||
## Data Extraction
|
||||
|
||||
### Extract from Text
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance, system, user, assistant
|
||||
|
||||
@guidance
|
||||
def extract_person_info(lm, text):
|
||||
"""Extract structured info from text."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
with assistant():
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
|
||||
lm += "Occupation: " + gen("occupation", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = "John Smith is a 35-year-old software engineer. Contact: john@example.com"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You extract structured information from text."
|
||||
|
||||
with user():
|
||||
lm = extract_person_info(lm, text)
|
||||
|
||||
print(f"Name: {lm['name']}")
|
||||
print(f"Age: {lm['age']}")
|
||||
print(f"Occupation: {lm['occupation']}")
|
||||
print(f"Email: {lm['email']}")
|
||||
```
|
||||
|
||||
### Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract multiple entity types."""
|
||||
lm += f"Analyze: {text}\n\n"
|
||||
|
||||
# Person entities
|
||||
lm += "People:\n"
|
||||
for i in range(3): # Up to 3 people
|
||||
lm += f"- " + gen(f"person_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Organization entities
|
||||
lm += "\nOrganizations:\n"
|
||||
for i in range(2): # Up to 2 orgs
|
||||
lm += f"- " + gen(f"org_{i}", regex=r"[A-Za-z0-9 ]+", stop="\n") + "\n"
|
||||
|
||||
# Dates
|
||||
lm += "\nDates:\n"
|
||||
for i in range(2): # Up to 2 dates
|
||||
lm += f"- " + gen(f"date_{i}", regex=r"\d{4}-\d{2}-\d{2}", stop="\n") + "\n"
|
||||
|
||||
# Locations
|
||||
lm += "\nLocations:\n"
|
||||
for i in range(2): # Up to 2 locations
|
||||
lm += f"- " + gen(f"location_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = """
|
||||
Tim Cook and Satya Nadella met at Microsoft headquarters in Redmond on 2024-09-15
|
||||
to discuss the collaboration between Apple and Microsoft. The meeting continued
|
||||
in Cupertino on 2024-09-20.
|
||||
"""
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = extract_entities(lm, text)
|
||||
```
|
||||
|
||||
### Batch Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def batch_extract(lm, texts):
|
||||
"""Extract from multiple texts."""
|
||||
lm += "Batch Extraction Results:\n\n"
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
lm += f"=== Item {i+1} ===\n"
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen(f"name_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Sentiment: " + gen(
|
||||
f"sentiment_{i}",
|
||||
regex=r"(positive|negative|neutral)",
|
||||
stop="\n"
|
||||
) + "\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
texts = [
|
||||
"Alice is happy with the product",
|
||||
"Bob is disappointed with the service",
|
||||
"Carol has no strong feelings either way"
|
||||
]
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = batch_extract(lm, texts)
|
||||
```
|
||||
|
||||
## Classification Systems
|
||||
|
||||
### Sentiment Analysis
|
||||
|
||||
```python
|
||||
from guidance import models, select, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
text = "This product is absolutely amazing! Best purchase ever."
|
||||
|
||||
lm += f"Text: {text}\n\n"
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "negative", "neutral"],
|
||||
name="sentiment"
|
||||
)
|
||||
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]{1,3}") + "%\n"
|
||||
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=50)
|
||||
|
||||
print(f"Sentiment: {lm['sentiment']}")
|
||||
print(f"Confidence: {lm['confidence']}%")
|
||||
print(f"Reasoning: {lm['reasoning']}")
|
||||
```
|
||||
|
||||
### Multi-Label Classification
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def classify_article(lm, text):
|
||||
"""Classify article with multiple labels."""
|
||||
lm += f"Article: {text}\n\n"
|
||||
|
||||
# Primary category
|
||||
lm += "Primary Category: " + select(
|
||||
["Technology", "Business", "Science", "Politics", "Entertainment"],
|
||||
name="primary_category"
|
||||
) + "\n"
|
||||
|
||||
# Secondary categories (up to 3)
|
||||
lm += "\nSecondary Categories:\n"
|
||||
categories = ["Technology", "Business", "Science", "Politics", "Entertainment"]
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + select(categories, name=f"secondary_{i}") + "\n"
|
||||
|
||||
# Tags
|
||||
lm += "\nTags: " + gen("tags", stop="\n", max_tokens=50) + "\n"
|
||||
|
||||
# Target audience
|
||||
lm += "Target Audience: " + select(
|
||||
["General", "Expert", "Beginner"],
|
||||
name="audience"
|
||||
)
|
||||
|
||||
return lm
|
||||
|
||||
article = """
|
||||
Apple announced new AI features in iOS 18, leveraging machine learning to improve
|
||||
battery life and performance. The company's stock rose 5% following the announcement.
|
||||
"""
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = classify_article(lm, article)
|
||||
```
|
||||
|
||||
### Intent Classification
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def classify_intent(lm, message):
|
||||
"""Classify user intent."""
|
||||
lm += f"User Message: {message}\n\n"
|
||||
|
||||
# Intent
|
||||
lm += "Intent: " + select(
|
||||
["question", "complaint", "request", "feedback", "other"],
|
||||
name="intent"
|
||||
) + "\n"
|
||||
|
||||
# Urgency
|
||||
lm += "Urgency: " + select(
|
||||
["low", "medium", "high", "critical"],
|
||||
name="urgency"
|
||||
) + "\n"
|
||||
|
||||
# Department
|
||||
lm += "Route To: " + select(
|
||||
["support", "sales", "billing", "technical"],
|
||||
name="department"
|
||||
) + "\n"
|
||||
|
||||
# Sentiment
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "neutral", "negative"],
|
||||
name="sentiment"
|
||||
)
|
||||
|
||||
return lm
|
||||
|
||||
message = "My account was charged twice for the same order. Need help ASAP!"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = classify_intent(lm, message)
|
||||
|
||||
print(f"Intent: {lm['intent']}")
|
||||
print(f"Urgency: {lm['urgency']}")
|
||||
print(f"Department: {lm['department']}")
|
||||
```
|
||||
|
||||
## Agent Systems
|
||||
|
||||
### ReAct Agent
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select, guidance
|
||||
|
||||
@guidance(stateless=False)
|
||||
def react_agent(lm, question, tools, max_rounds=5):
|
||||
"""ReAct agent with tool use."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for round in range(max_rounds):
|
||||
# Thought
|
||||
lm += f"Thought {round+1}: " + gen("thought", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Action selection
|
||||
lm += "Action: " + select(
|
||||
list(tools.keys()) + ["answer"],
|
||||
name="action"
|
||||
)
|
||||
|
||||
if lm["action"] == "answer":
|
||||
lm += "\n\nFinal Answer: " + gen("answer", max_tokens=200)
|
||||
break
|
||||
|
||||
# Action input
|
||||
lm += "\nAction Input: " + gen("action_input", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Execute tool
|
||||
if lm["action"] in tools:
|
||||
try:
|
||||
result = tools[lm["action"]](lm["action_input"])
|
||||
lm += f"Observation: {result}\n\n"
|
||||
except Exception as e:
|
||||
lm += f"Observation: Error - {str(e)}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
# Define tools
|
||||
tools = {
|
||||
"calculator": lambda expr: eval(expr),
|
||||
"search": lambda query: f"Search results for '{query}': [Mock results]",
|
||||
"weather": lambda city: f"Weather in {city}: Sunny, 72°F"
|
||||
}
|
||||
|
||||
# Use agent
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = react_agent(lm, "What is (25 * 4) + 10?", tools)
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Multi-Agent System
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def coordinator_agent(lm, task):
|
||||
"""Coordinator that delegates to specialists."""
|
||||
lm += f"Task: {task}\n\n"
|
||||
|
||||
# Determine which specialist to use
|
||||
lm += "Specialist: " + select(
|
||||
["researcher", "writer", "coder", "analyst"],
|
||||
name="specialist"
|
||||
) + "\n"
|
||||
|
||||
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def researcher_agent(lm, query):
|
||||
"""Research specialist."""
|
||||
lm += f"Research Query: {query}\n\n"
|
||||
lm += "Findings:\n"
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + gen(f"finding_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def writer_agent(lm, topic):
|
||||
"""Writing specialist."""
|
||||
lm += f"Topic: {topic}\n\n"
|
||||
lm += "Title: " + gen("title", stop="\n", max_tokens=50) + "\n"
|
||||
lm += "Content:\n" + gen("content", max_tokens=500)
|
||||
return lm
|
||||
|
||||
# Coordination workflow
|
||||
task = "Write an article about AI safety"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = coordinator_agent(lm, task)
|
||||
|
||||
specialist = lm["specialist"]
|
||||
if specialist == "researcher":
|
||||
lm = researcher_agent(lm, task)
|
||||
elif specialist == "writer":
|
||||
lm = writer_agent(lm, task)
|
||||
```
|
||||
|
||||
### Tool Use with Validation
|
||||
|
||||
```python
|
||||
@guidance(stateless=False)
|
||||
def validated_tool_agent(lm, question):
|
||||
"""Agent with validated tool calls."""
|
||||
tools = {
|
||||
"add": lambda a, b: float(a) + float(b),
|
||||
"multiply": lambda a, b: float(a) * float(b),
|
||||
"divide": lambda a, b: float(a) / float(b) if float(b) != 0 else "Error: Division by zero"
|
||||
}
|
||||
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for i in range(5):
|
||||
# Select tool
|
||||
lm += "Tool: " + select(list(tools.keys()) + ["done"], name="tool")
|
||||
|
||||
if lm["tool"] == "done":
|
||||
lm += "\nAnswer: " + gen("answer", max_tokens=100)
|
||||
break
|
||||
|
||||
# Get validated numeric arguments
|
||||
lm += "\nArg1: " + gen("arg1", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
|
||||
lm += "Arg2: " + gen("arg2", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
|
||||
|
||||
# Execute
|
||||
result = tools[lm["tool"]](lm["arg1"], lm["arg2"])
|
||||
lm += f"Result: {result}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = validated_tool_agent(lm, "What is (10 + 5) * 3?")
|
||||
```
|
||||
|
||||
## Multi-Step Workflows
|
||||
|
||||
### Chain of Thought
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def chain_of_thought(lm, question):
|
||||
"""Multi-step reasoning with CoT."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
# Generate reasoning steps
|
||||
lm += "Let me think step by step:\n\n"
|
||||
for i in range(4):
|
||||
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Final answer
|
||||
lm += "\nTherefore, the answer is: " + gen("answer", stop="\n", max_tokens=50)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = chain_of_thought(lm, "If a train travels 60 mph for 2.5 hours, how far does it go?")
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Self-Consistency
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def self_consistency(lm, question, num_samples=3):
|
||||
"""Generate multiple reasoning paths and aggregate."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
answers = []
|
||||
for i in range(num_samples):
|
||||
lm += f"=== Attempt {i+1} ===\n"
|
||||
lm += "Reasoning: " + gen(f"reasoning_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
lm += "Answer: " + gen(f"answer_{i}", stop="\n", max_tokens=50) + "\n\n"
|
||||
answers.append(lm[f"answer_{i}"])
|
||||
|
||||
# Aggregate (simple majority vote)
|
||||
from collections import Counter
|
||||
most_common = Counter(answers).most_common(1)[0][0]
|
||||
|
||||
lm += f"Final Answer (by majority): {most_common}\n"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = self_consistency(lm, "What is 15% of 200?")
|
||||
```
|
||||
|
||||
### Planning and Execution
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def plan_and_execute(lm, goal):
|
||||
"""Plan tasks then execute them."""
|
||||
lm += f"Goal: {goal}\n\n"
|
||||
|
||||
# Planning phase
|
||||
lm += "Plan:\n"
|
||||
num_steps = 4
|
||||
for i in range(num_steps):
|
||||
lm += f"{i+1}. " + gen(f"plan_step_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Execution phase
|
||||
lm += "\nExecution:\n\n"
|
||||
for i in range(num_steps):
|
||||
lm += f"Step {i+1}: {lm[f'plan_step_{i}']}\n"
|
||||
lm += "Status: " + select(["completed", "in-progress", "blocked"], name=f"status_{i}") + "\n"
|
||||
lm += "Result: " + gen(f"result_{i}", stop="\n", max_tokens=150) + "\n\n"
|
||||
|
||||
# Summary
|
||||
lm += "Summary: " + gen("summary", max_tokens=200)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = plan_and_execute(lm, "Build a REST API for a blog platform")
|
||||
```
|
||||
|
||||
## Code Generation
|
||||
|
||||
### Python Function
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_python_function(lm, description):
|
||||
"""Generate Python function from description."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
|
||||
# Function signature
|
||||
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
|
||||
lm += gen("params", regex=r"[a-z_][a-z0-9_]*(, [a-z_][a-z0-9_]*)*") + "):\n"
|
||||
|
||||
# Docstring
|
||||
lm += ' """' + gen("docstring", stop='"""', max_tokens=100) + '"""\n'
|
||||
|
||||
# Function body
|
||||
lm += " " + gen("body", stop="\n", max_tokens=200) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_python_function(lm, "Check if a number is prime")
|
||||
|
||||
print(lm)
|
||||
```
|
||||
|
||||
### SQL Query
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_sql(lm, description):
|
||||
"""Generate SQL query from description."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
lm += "SQL Query:\n"
|
||||
|
||||
# SELECT clause
|
||||
lm += "SELECT " + gen("select_clause", stop=" FROM", max_tokens=100)
|
||||
|
||||
# FROM clause
|
||||
lm += " FROM " + gen("from_clause", stop=" WHERE", max_tokens=50)
|
||||
|
||||
# WHERE clause (optional)
|
||||
lm += " WHERE " + gen("where_clause", stop=";", max_tokens=100) + ";"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_sql(lm, "Get all users who signed up in the last 30 days")
|
||||
```
|
||||
|
||||
### API Endpoint
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_api_endpoint(lm, description):
|
||||
"""Generate REST API endpoint."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
|
||||
# HTTP method
|
||||
lm += "Method: " + select(["GET", "POST", "PUT", "DELETE"], name="method") + "\n"
|
||||
|
||||
# Path
|
||||
lm += "Path: /" + gen("path", regex=r"[a-z0-9/-]+", stop="\n") + "\n"
|
||||
|
||||
# Request body (if POST/PUT)
|
||||
if lm["method"] in ["POST", "PUT"]:
|
||||
lm += "\nRequest Body:\n"
|
||||
lm += "{\n"
|
||||
lm += ' "field1": ' + gen("field1", regex=r'"[a-z_]+"') + ",\n"
|
||||
lm += ' "field2": ' + gen("field2", regex=r'"[a-z_]+"') + "\n"
|
||||
lm += "}\n"
|
||||
|
||||
# Response
|
||||
lm += "\nResponse (200 OK):\n"
|
||||
lm += "{\n"
|
||||
lm += ' "status": "success",\n'
|
||||
lm += ' "data": ' + gen("response_data", max_tokens=100) + "\n"
|
||||
lm += "}\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_api_endpoint(lm, "Create a new blog post")
|
||||
```
|
||||
|
||||
## Production Tips
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def safe_extraction(lm, text):
|
||||
"""Extract with fallback handling."""
|
||||
try:
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n", max_tokens=30)
|
||||
return lm
|
||||
except Exception as e:
|
||||
# Fallback to less strict extraction
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen("name", stop="\n", max_tokens=30)
|
||||
return lm
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def cached_generation(text):
|
||||
"""Cache LLM generations."""
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm += f"Analyze: {text}\n"
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
return lm["sentiment"]
|
||||
|
||||
# First call: hits LLM
|
||||
result1 = cached_generation("This is great!")
|
||||
|
||||
# Second call: returns cached result
|
||||
result2 = cached_generation("This is great!") # Instant!
|
||||
```
|
||||
|
||||
### Monitoring
|
||||
|
||||
```python
|
||||
import time
|
||||
|
||||
@guidance
|
||||
def monitored_generation(lm, text):
|
||||
"""Track generation metrics."""
|
||||
start_time = time.time()
|
||||
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Analysis: " + gen("analysis", max_tokens=100)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Log metrics
|
||||
print(f"Generation time: {elapsed:.2f}s")
|
||||
print(f"Output length: {len(lm['analysis'])} chars")
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
def batch_process(texts, batch_size=10):
|
||||
"""Process texts in batches."""
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
results = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i+batch_size]
|
||||
|
||||
for text in batch:
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "negative", "neutral"],
|
||||
name=f"sentiment_{i}"
|
||||
) + "\n\n"
|
||||
|
||||
results.extend([lm[f"sentiment_{i}"] for i in range(len(batch))])
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Guidance Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
|
||||
- **Guidance Docs**: https://guidance.readthedocs.io
|
||||
- **Community Examples**: https://github.com/guidance-ai/guidance/discussions
|
||||
307
skills/mlops/llava/SKILL.md
Normal file
307
skills/mlops/llava/SKILL.md
Normal file
|
|
@ -0,0 +1,307 @@
|
|||
---
|
||||
name: llava
|
||||
description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [transformers, torch, pillow]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA]
|
||||
|
||||
---
|
||||
|
||||
# LLaVA - Large Language and Vision Assistant
|
||||
|
||||
Open-source vision-language model for conversational image understanding.
|
||||
|
||||
## When to use LLaVA
|
||||
|
||||
**Use when:**
|
||||
- Building vision-language chatbots
|
||||
- Visual question answering (VQA)
|
||||
- Image description and captioning
|
||||
- Multi-turn image conversations
|
||||
- Visual instruction following
|
||||
- Document understanding with images
|
||||
|
||||
**Metrics**:
|
||||
- **23,000+ GitHub stars**
|
||||
- GPT-4V level capabilities (targeted)
|
||||
- Apache 2.0 License
|
||||
- Multiple model sizes (7B-34B params)
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **GPT-4V**: Highest quality, API-based
|
||||
- **CLIP**: Simple zero-shot classification
|
||||
- **BLIP-2**: Better for captioning only
|
||||
- **Flamingo**: Research, not open-source
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/haotian-liu/LLaVA
|
||||
cd LLaVA
|
||||
|
||||
# Install
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
from llava.model.builder import load_pretrained_model
|
||||
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
|
||||
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from llava.conversation import conv_templates
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
# Load model
|
||||
model_path = "liuhaotian/llava-v1.5-7b"
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path=model_path,
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path(model_path)
|
||||
)
|
||||
|
||||
# Load image
|
||||
image = Image.open("image.jpg")
|
||||
image_tensor = process_images([image], image_processor, model.config)
|
||||
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
||||
|
||||
# Create conversation
|
||||
conv = conv_templates["llava_v1"].copy()
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
# Generate response
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=image_tensor,
|
||||
do_sample=True,
|
||||
temperature=0.2,
|
||||
max_new_tokens=512
|
||||
)
|
||||
|
||||
response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Available models
|
||||
|
||||
| Model | Parameters | VRAM | Quality |
|
||||
|-------|------------|------|---------|
|
||||
| LLaVA-v1.5-7B | 7B | ~14 GB | Good |
|
||||
| LLaVA-v1.5-13B | 13B | ~28 GB | Better |
|
||||
| LLaVA-v1.6-34B | 34B | ~70 GB | Best |
|
||||
|
||||
```python
|
||||
# Load different models
|
||||
model_7b = "liuhaotian/llava-v1.5-7b"
|
||||
model_13b = "liuhaotian/llava-v1.5-13b"
|
||||
model_34b = "liuhaotian/llava-v1.6-34b"
|
||||
|
||||
# 4-bit quantization for lower VRAM
|
||||
load_4bit = True # Reduces VRAM by ~4×
|
||||
```
|
||||
|
||||
## CLI usage
|
||||
|
||||
```bash
|
||||
# Single image query
|
||||
python -m llava.serve.cli \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--image-file image.jpg \
|
||||
--query "What is in this image?"
|
||||
|
||||
# Multi-turn conversation
|
||||
python -m llava.serve.cli \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--image-file image.jpg
|
||||
# Then type questions interactively
|
||||
```
|
||||
|
||||
## Web UI (Gradio)
|
||||
|
||||
```bash
|
||||
# Launch Gradio interface
|
||||
python -m llava.serve.gradio_web_server \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--load-4bit # Optional: reduce VRAM
|
||||
|
||||
# Access at http://localhost:7860
|
||||
```
|
||||
|
||||
## Multi-turn conversations
|
||||
|
||||
```python
|
||||
# Initialize conversation
|
||||
conv = conv_templates["llava_v1"].copy()
|
||||
|
||||
# Turn 1
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response1 = generate(conv, model, image) # "A dog playing in a park"
|
||||
|
||||
# Turn 2
|
||||
conv.messages[-1][1] = response1 # Add previous response
|
||||
conv.append_message(conv.roles[0], "What breed is the dog?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response2 = generate(conv, model, image) # "Golden Retriever"
|
||||
|
||||
# Turn 3
|
||||
conv.messages[-1][1] = response2
|
||||
conv.append_message(conv.roles[0], "What time of day is it?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response3 = generate(conv, model, image)
|
||||
```
|
||||
|
||||
## Common tasks
|
||||
|
||||
### Image captioning
|
||||
|
||||
```python
|
||||
question = "Describe this image in detail."
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Visual question answering
|
||||
|
||||
```python
|
||||
question = "How many people are in the image?"
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Object detection (textual)
|
||||
|
||||
```python
|
||||
question = "List all the objects you can see in this image."
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Scene understanding
|
||||
|
||||
```python
|
||||
question = "What is happening in this scene?"
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Document understanding
|
||||
|
||||
```python
|
||||
question = "What is the main topic of this document?"
|
||||
response = ask(model, document_image, question)
|
||||
```
|
||||
|
||||
## Training custom model
|
||||
|
||||
```bash
|
||||
# Stage 1: Feature alignment (558K image-caption pairs)
|
||||
bash scripts/v1_5/pretrain.sh
|
||||
|
||||
# Stage 2: Visual instruction tuning (150K instruction data)
|
||||
bash scripts/v1_5/finetune.sh
|
||||
```
|
||||
|
||||
## Quantization (reduce VRAM)
|
||||
|
||||
```python
|
||||
# 4-bit quantization
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path="liuhaotian/llava-v1.5-13b",
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"),
|
||||
load_4bit=True # Reduces VRAM ~4×
|
||||
)
|
||||
|
||||
# 8-bit quantization
|
||||
load_8bit=True # Reduces VRAM ~2×
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with 7B model** - Good quality, manageable VRAM
|
||||
2. **Use 4-bit quantization** - Reduces VRAM significantly
|
||||
3. **GPU required** - CPU inference extremely slow
|
||||
4. **Clear prompts** - Specific questions get better answers
|
||||
5. **Multi-turn conversations** - Maintain conversation context
|
||||
6. **Temperature 0.2-0.7** - Balance creativity/consistency
|
||||
7. **max_new_tokens 512-1024** - For detailed responses
|
||||
8. **Batch processing** - Process multiple images sequentially
|
||||
|
||||
## Performance
|
||||
|
||||
| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) |
|
||||
|-------|-------------|--------------|------------------|
|
||||
| 7B | ~14 GB | ~4 GB | ~20 |
|
||||
| 13B | ~28 GB | ~8 GB | ~12 |
|
||||
| 34B | ~70 GB | ~18 GB | ~5 |
|
||||
|
||||
*On A100 GPU*
|
||||
|
||||
## Benchmarks
|
||||
|
||||
LLaVA achieves competitive scores on:
|
||||
- **VQAv2**: 78.5%
|
||||
- **GQA**: 62.0%
|
||||
- **MM-Vet**: 35.4%
|
||||
- **MMBench**: 64.3%
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Hallucinations** - May describe things not in image
|
||||
2. **Spatial reasoning** - Struggles with precise locations
|
||||
3. **Small text** - Difficulty reading fine print
|
||||
4. **Object counting** - Imprecise for many objects
|
||||
5. **VRAM requirements** - Need powerful GPU
|
||||
6. **Inference speed** - Slower than CLIP
|
||||
|
||||
## Integration with frameworks
|
||||
|
||||
### LangChain
|
||||
|
||||
```python
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
class LLaVALLM(LLM):
|
||||
def _call(self, prompt, stop=None):
|
||||
# Custom LLaVA inference
|
||||
return response
|
||||
|
||||
llm = LLaVALLM()
|
||||
```
|
||||
|
||||
### Gradio App
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
|
||||
def chat(image, text, history):
|
||||
response = ask_llava(model, image, text)
|
||||
return response
|
||||
|
||||
demo = gr.ChatInterface(
|
||||
chat,
|
||||
additional_inputs=[gr.Image(type="pil")],
|
||||
title="LLaVA Chat"
|
||||
)
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+
|
||||
- **Paper**: https://arxiv.org/abs/2304.08485
|
||||
- **Demo**: https://llava.hliu.cc
|
||||
- **Models**: https://huggingface.co/liuhaotian
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
197
skills/mlops/llava/references/training.md
Normal file
197
skills/mlops/llava/references/training.md
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
# LLaVA Training Guide
|
||||
|
||||
Guide to training and fine-tuning LLaVA models.
|
||||
|
||||
## Training stages
|
||||
|
||||
### Stage 1: Feature alignment (Pretraining)
|
||||
|
||||
**Purpose**: Align vision encoder with language model
|
||||
|
||||
**Data**: 558K image-caption pairs (CC3M subset)
|
||||
|
||||
```bash
|
||||
# Download pretrained projector or train from scratch
|
||||
bash scripts/v1_5/pretrain.sh
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
- Base model: Vicuna-7B or LLaMA-2-7B
|
||||
- Vision encoder: CLIP ViT-L/14
|
||||
- Training time: ~20 hours on 8× A100
|
||||
|
||||
### Stage 2: Visual instruction tuning
|
||||
|
||||
**Purpose**: Teach model to follow visual instructions
|
||||
|
||||
**Data**: 150K GPT-generated multimodal instruction data
|
||||
|
||||
```bash
|
||||
# Fine-tune with instruction data
|
||||
bash scripts/v1_5/finetune.sh
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
- Epochs: 1
|
||||
- Batch size: 128 (across 8 GPUs)
|
||||
- Learning rate: 2e-5
|
||||
- Training time: ~24 hours on 8× A100
|
||||
|
||||
## Data format
|
||||
|
||||
### Instruction data format
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "001",
|
||||
"image": "path/to/image.jpg",
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "<image>\nWhat is in this image?"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "The image shows a dog playing in a park."
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "What breed is the dog?"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "It appears to be a Golden Retriever."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Fine-tuning on custom data
|
||||
|
||||
### Prepare your data
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
# Create instruction data
|
||||
data = []
|
||||
for image_path, qa_pairs in your_dataset:
|
||||
conversations = []
|
||||
for q, a in qa_pairs:
|
||||
conversations.append({"from": "human", "value": f"<image>\n{q}"})
|
||||
conversations.append({"from": "gpt", "value": a})
|
||||
|
||||
data.append({
|
||||
"id": str(len(data)),
|
||||
"image": image_path,
|
||||
"conversations": conversations
|
||||
})
|
||||
|
||||
# Save
|
||||
with open("custom_data.json", "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
```
|
||||
|
||||
### Fine-tune script
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
# Set paths
|
||||
DATA_PATH="custom_data.json"
|
||||
IMAGE_FOLDER="path/to/images"
|
||||
MODEL_PATH="liuhaotian/llava-v1.5-7b"
|
||||
OUTPUT_DIR="./checkpoints/llava-custom"
|
||||
|
||||
# Fine-tune
|
||||
deepspeed llava/train/train_mem.py \
|
||||
--deepspeed ./scripts/zero2.json \
|
||||
--model_name_or_path $MODEL_PATH \
|
||||
--version v1 \
|
||||
--data_path $DATA_PATH \
|
||||
--image_folder $IMAGE_FOLDER \
|
||||
--vision_tower openai/clip-vit-large-patch14-336 \
|
||||
--mm_projector_type mlp2x_gelu \
|
||||
--mm_vision_select_layer -2 \
|
||||
--mm_use_im_start_end False \
|
||||
--mm_use_im_patch_token False \
|
||||
--image_aspect_ratio pad \
|
||||
--group_by_modality_length True \
|
||||
--bf16 True \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--num_train_epochs 1 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--per_device_eval_batch_size 4 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--evaluation_strategy "no" \
|
||||
--save_strategy "steps" \
|
||||
--save_steps 50000 \
|
||||
--save_total_limit 1 \
|
||||
--learning_rate 2e-5 \
|
||||
--weight_decay 0. \
|
||||
--warmup_ratio 0.03 \
|
||||
--lr_scheduler_type "cosine" \
|
||||
--logging_steps 1 \
|
||||
--tf32 True \
|
||||
--model_max_length 2048 \
|
||||
--gradient_checkpointing True \
|
||||
--dataloader_num_workers 4 \
|
||||
--lazy_preprocess True \
|
||||
--report_to wandb
|
||||
```
|
||||
|
||||
## LoRA fine-tuning (memory efficient)
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
# LoRA config
|
||||
lora_config = LoraConfig(
|
||||
r=8, # LoRA rank
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(base_model, lora_config)
|
||||
|
||||
# Train with much lower memory
|
||||
```
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
### Full fine-tuning
|
||||
|
||||
- **7B model**: 8× A100 (40GB)
|
||||
- **13B model**: 8× A100 (80GB)
|
||||
- **Training time**: 20-48 hours
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
- **7B model**: 1× A100 (40GB)
|
||||
- **13B model**: 2× A100 (40GB)
|
||||
- **Training time**: 10-24 hours
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with pretrained** - Don't train from scratch
|
||||
2. **Use LoRA for efficiency** - 10× less memory
|
||||
3. **Quality over quantity** - 1K high-quality > 10K low-quality
|
||||
4. **Multi-turn conversations** - More engaging than single Q&A
|
||||
5. **Diverse images** - Cover different scenarios
|
||||
6. **Clear instructions** - Specific questions get better answers
|
||||
7. **Monitor loss** - Should decrease smoothly
|
||||
8. **Save checkpoints** - Training can fail
|
||||
9. **Test regularly** - Validate on held-out set
|
||||
10. **Use DeepSpeed** - For multi-GPU training
|
||||
|
||||
## Resources
|
||||
|
||||
- **Training script**: https://github.com/haotian-liu/LLaVA/tree/main/scripts
|
||||
- **Data format**: https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md
|
||||
- **Paper**: https://arxiv.org/abs/2304.08485
|
||||
386
skills/mlops/nemo-curator/SKILL.md
Normal file
386
skills/mlops/nemo-curator/SKILL.md
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
---
|
||||
name: nemo-curator
|
||||
description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [nemo-curator, cudf, dask, rapids]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data]
|
||||
|
||||
---
|
||||
|
||||
# NeMo Curator - GPU-Accelerated Data Curation
|
||||
|
||||
NVIDIA's toolkit for preparing high-quality training data for LLMs.
|
||||
|
||||
## When to use NeMo Curator
|
||||
|
||||
**Use NeMo Curator when:**
|
||||
- Preparing LLM training data from web scrapes (Common Crawl)
|
||||
- Need fast deduplication (16× faster than CPU)
|
||||
- Curating multi-modal datasets (text, images, video, audio)
|
||||
- Filtering low-quality or toxic content
|
||||
- Scaling data processing across GPU cluster
|
||||
|
||||
**Performance**:
|
||||
- **16× faster** fuzzy deduplication (8TB RedPajama v2)
|
||||
- **40% lower TCO** vs CPU alternatives
|
||||
- **Near-linear scaling** across GPU nodes
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **datatrove**: CPU-based, open-source data processing
|
||||
- **dolma**: Allen AI's data toolkit
|
||||
- **Ray Data**: General ML data processing (no curation focus)
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Text curation (CUDA 12)
|
||||
uv pip install "nemo-curator[text_cuda12]"
|
||||
|
||||
# All modalities
|
||||
uv pip install "nemo-curator[all_cuda12]"
|
||||
|
||||
# CPU-only (slower)
|
||||
uv pip install "nemo-curator[cpu]"
|
||||
```
|
||||
|
||||
### Basic text curation pipeline
|
||||
|
||||
```python
|
||||
from nemo_curator import ScoreFilter, Modify
|
||||
from nemo_curator.datasets import DocumentDataset
|
||||
import pandas as pd
|
||||
|
||||
# Load data
|
||||
df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]})
|
||||
dataset = DocumentDataset(df)
|
||||
|
||||
# Quality filtering
|
||||
def quality_score(doc):
|
||||
return len(doc["text"].split()) > 5 # Filter short docs
|
||||
|
||||
filtered = ScoreFilter(quality_score)(dataset)
|
||||
|
||||
# Deduplication
|
||||
from nemo_curator.modules import ExactDuplicates
|
||||
deduped = ExactDuplicates()(filtered)
|
||||
|
||||
# Save
|
||||
deduped.to_parquet("curated_data/")
|
||||
```
|
||||
|
||||
## Data curation pipeline
|
||||
|
||||
### Stage 1: Quality filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import (
|
||||
WordCountFilter,
|
||||
RepeatedLinesFilter,
|
||||
UrlRatioFilter,
|
||||
NonAlphaNumericFilter
|
||||
)
|
||||
|
||||
# Apply 30+ heuristic filters
|
||||
from nemo_curator import ScoreFilter
|
||||
|
||||
# Word count filter
|
||||
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
|
||||
|
||||
# Remove repetitive content
|
||||
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
|
||||
|
||||
# URL ratio filter
|
||||
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
|
||||
```
|
||||
|
||||
### Stage 2: Deduplication
|
||||
|
||||
**Exact deduplication**:
|
||||
```python
|
||||
from nemo_curator.modules import ExactDuplicates
|
||||
|
||||
# Remove exact duplicates
|
||||
deduped = ExactDuplicates(id_field="id", text_field="text")(dataset)
|
||||
```
|
||||
|
||||
**Fuzzy deduplication** (16× faster on GPU):
|
||||
```python
|
||||
from nemo_curator.modules import FuzzyDuplicates
|
||||
|
||||
# MinHash + LSH deduplication
|
||||
fuzzy_dedup = FuzzyDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
num_hashes=260, # MinHash parameters
|
||||
num_buckets=20,
|
||||
hash_method="md5"
|
||||
)
|
||||
|
||||
deduped = fuzzy_dedup(dataset)
|
||||
```
|
||||
|
||||
**Semantic deduplication**:
|
||||
```python
|
||||
from nemo_curator.modules import SemanticDuplicates
|
||||
|
||||
# Embedding-based deduplication
|
||||
semantic_dedup = SemanticDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
threshold=0.8 # Cosine similarity threshold
|
||||
)
|
||||
|
||||
deduped = semantic_dedup(dataset)
|
||||
```
|
||||
|
||||
### Stage 3: PII redaction
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import Modify
|
||||
from nemo_curator.modifiers import PIIRedactor
|
||||
|
||||
# Redact personally identifiable information
|
||||
pii_redactor = PIIRedactor(
|
||||
supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"],
|
||||
anonymize_action="replace" # or "redact"
|
||||
)
|
||||
|
||||
redacted = Modify(pii_redactor)(dataset)
|
||||
```
|
||||
|
||||
### Stage 4: Classifier filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import QualityClassifier
|
||||
|
||||
# Quality classification
|
||||
quality_clf = QualityClassifier(
|
||||
model_path="nvidia/quality-classifier-deberta",
|
||||
batch_size=256,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Filter low-quality documents
|
||||
high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
### GPU vs CPU performance
|
||||
|
||||
| Operation | CPU (16 cores) | GPU (A100) | Speedup |
|
||||
|-----------|----------------|------------|---------|
|
||||
| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× |
|
||||
| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× |
|
||||
| Quality filtering | 2 hours | 0.2 hours | 10× |
|
||||
|
||||
### Multi-GPU scaling
|
||||
|
||||
```python
|
||||
from nemo_curator import get_client
|
||||
import dask_cuda
|
||||
|
||||
# Initialize GPU cluster
|
||||
client = get_client(cluster_type="gpu", n_workers=8)
|
||||
|
||||
# Process with 8 GPUs
|
||||
deduped = FuzzyDuplicates(...)(dataset)
|
||||
```
|
||||
|
||||
## Multi-modal curation
|
||||
|
||||
### Image curation
|
||||
|
||||
```python
|
||||
from nemo_curator.image import (
|
||||
AestheticFilter,
|
||||
NSFWFilter,
|
||||
CLIPEmbedder
|
||||
)
|
||||
|
||||
# Aesthetic scoring
|
||||
aesthetic_filter = AestheticFilter(threshold=5.0)
|
||||
filtered_images = aesthetic_filter(image_dataset)
|
||||
|
||||
# NSFW detection
|
||||
nsfw_filter = NSFWFilter(threshold=0.9)
|
||||
safe_images = nsfw_filter(filtered_images)
|
||||
|
||||
# Generate CLIP embeddings
|
||||
clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32")
|
||||
image_embeddings = clip_embedder(safe_images)
|
||||
```
|
||||
|
||||
### Video curation
|
||||
|
||||
```python
|
||||
from nemo_curator.video import (
|
||||
SceneDetector,
|
||||
ClipExtractor,
|
||||
InternVideo2Embedder
|
||||
)
|
||||
|
||||
# Detect scenes
|
||||
scene_detector = SceneDetector(threshold=27.0)
|
||||
scenes = scene_detector(video_dataset)
|
||||
|
||||
# Extract clips
|
||||
clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0)
|
||||
clips = clip_extractor(scenes)
|
||||
|
||||
# Generate embeddings
|
||||
video_embedder = InternVideo2Embedder()
|
||||
video_embeddings = video_embedder(clips)
|
||||
```
|
||||
|
||||
### Audio curation
|
||||
|
||||
```python
|
||||
from nemo_curator.audio import (
|
||||
ASRInference,
|
||||
WERFilter,
|
||||
DurationFilter
|
||||
)
|
||||
|
||||
# ASR transcription
|
||||
asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc")
|
||||
transcribed = asr(audio_dataset)
|
||||
|
||||
# Filter by WER (word error rate)
|
||||
wer_filter = WERFilter(max_wer=0.3)
|
||||
high_quality_audio = wer_filter(transcribed)
|
||||
|
||||
# Duration filtering
|
||||
duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0)
|
||||
filtered_audio = duration_filter(high_quality_audio)
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Web scrape curation (Common Crawl)
|
||||
|
||||
```python
|
||||
from nemo_curator import ScoreFilter, Modify
|
||||
from nemo_curator.filters import *
|
||||
from nemo_curator.modules import *
|
||||
from nemo_curator.datasets import DocumentDataset
|
||||
|
||||
# Load Common Crawl data
|
||||
dataset = DocumentDataset.read_parquet("common_crawl/*.parquet")
|
||||
|
||||
# Pipeline
|
||||
pipeline = [
|
||||
# 1. Quality filtering
|
||||
WordCountFilter(min_words=100, max_words=50000),
|
||||
RepeatedLinesFilter(max_repeated_line_fraction=0.2),
|
||||
SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3),
|
||||
UrlRatioFilter(max_url_ratio=0.3),
|
||||
|
||||
# 2. Language filtering
|
||||
LanguageIdentificationFilter(target_languages=["en"]),
|
||||
|
||||
# 3. Deduplication
|
||||
ExactDuplicates(id_field="id", text_field="text"),
|
||||
FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260),
|
||||
|
||||
# 4. PII redaction
|
||||
PIIRedactor(),
|
||||
|
||||
# 5. NSFW filtering
|
||||
NSFWClassifier(threshold=0.8)
|
||||
]
|
||||
|
||||
# Execute
|
||||
for stage in pipeline:
|
||||
dataset = stage(dataset)
|
||||
|
||||
# Save
|
||||
dataset.to_parquet("curated_common_crawl/")
|
||||
```
|
||||
|
||||
### Distributed processing
|
||||
|
||||
```python
|
||||
from nemo_curator import get_client
|
||||
from dask_cuda import LocalCUDACluster
|
||||
|
||||
# Multi-GPU cluster
|
||||
cluster = LocalCUDACluster(n_workers=8)
|
||||
client = get_client(cluster=cluster)
|
||||
|
||||
# Process large dataset
|
||||
dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet")
|
||||
deduped = FuzzyDuplicates(...)(dataset)
|
||||
|
||||
# Cleanup
|
||||
client.close()
|
||||
cluster.close()
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Fuzzy deduplication (8TB RedPajama v2)
|
||||
|
||||
- **CPU (256 cores)**: 120 hours
|
||||
- **GPU (8× A100)**: 7.5 hours
|
||||
- **Speedup**: 16×
|
||||
|
||||
### Exact deduplication (1TB)
|
||||
|
||||
- **CPU (64 cores)**: 8 hours
|
||||
- **GPU (4× A100)**: 0.5 hours
|
||||
- **Speedup**: 16×
|
||||
|
||||
### Quality filtering (100GB)
|
||||
|
||||
- **CPU (32 cores)**: 2 hours
|
||||
- **GPU (2× A100)**: 0.2 hours
|
||||
- **Speedup**: 10×
|
||||
|
||||
## Cost comparison
|
||||
|
||||
**CPU-based curation** (AWS c5.18xlarge × 10):
|
||||
- Cost: $3.60/hour × 10 = $36/hour
|
||||
- Time for 8TB: 120 hours
|
||||
- **Total**: $4,320
|
||||
|
||||
**GPU-based curation** (AWS p4d.24xlarge × 2):
|
||||
- Cost: $32.77/hour × 2 = $65.54/hour
|
||||
- Time for 8TB: 7.5 hours
|
||||
- **Total**: $491.55
|
||||
|
||||
**Savings**: 89% reduction ($3,828 saved)
|
||||
|
||||
## Supported data formats
|
||||
|
||||
- **Input**: Parquet, JSONL, CSV
|
||||
- **Output**: Parquet (recommended), JSONL
|
||||
- **WebDataset**: TAR archives for multi-modal
|
||||
|
||||
## Use cases
|
||||
|
||||
**Production deployments**:
|
||||
- NVIDIA used NeMo Curator to prepare Nemotron-4 training data
|
||||
- Open-source datasets curated: RedPajama v2, The Pile
|
||||
|
||||
## References
|
||||
|
||||
- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics
|
||||
- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+
|
||||
- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/
|
||||
- **Version**: 0.4.0+
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
|
||||
87
skills/mlops/nemo-curator/references/deduplication.md
Normal file
87
skills/mlops/nemo-curator/references/deduplication.md
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
# Deduplication Guide
|
||||
|
||||
Complete guide to exact, fuzzy, and semantic deduplication.
|
||||
|
||||
## Exact deduplication
|
||||
|
||||
Remove documents with identical content.
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import ExactDuplicates
|
||||
|
||||
# Exact deduplication
|
||||
exact_dedup = ExactDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
hash_method="md5" # or "sha256"
|
||||
)
|
||||
|
||||
deduped = exact_dedup(dataset)
|
||||
```
|
||||
|
||||
**Performance**: ~16× faster on GPU vs CPU
|
||||
|
||||
## Fuzzy deduplication
|
||||
|
||||
Remove near-duplicate documents using MinHash + LSH.
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import FuzzyDuplicates
|
||||
|
||||
fuzzy_dedup = FuzzyDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
num_hashes=260, # MinHash permutations (more = accurate)
|
||||
num_buckets=20, # LSH buckets (more = faster, less recall)
|
||||
hash_method="md5",
|
||||
jaccard_threshold=0.8 # Similarity threshold
|
||||
)
|
||||
|
||||
deduped = fuzzy_dedup(dataset)
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `num_hashes`: 128-512 (default 260)
|
||||
- `num_buckets`: 10-50 (default 20)
|
||||
- `jaccard_threshold`: 0.7-0.9 (default 0.8)
|
||||
|
||||
**Performance**: 16× faster on 8TB dataset (120h → 7.5h)
|
||||
|
||||
## Semantic deduplication
|
||||
|
||||
Remove semantically similar documents using embeddings.
|
||||
|
||||
```python
|
||||
from nemo_curator.modules import SemanticDuplicates
|
||||
|
||||
semantic_dedup = SemanticDuplicates(
|
||||
id_field="id",
|
||||
text_field="text",
|
||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
embedding_batch_size=256,
|
||||
threshold=0.85, # Cosine similarity threshold
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
deduped = semantic_dedup(dataset)
|
||||
```
|
||||
|
||||
**Models**:
|
||||
- `all-MiniLM-L6-v2`: Fast, 384 dims
|
||||
- `all-mpnet-base-v2`: Better quality, 768 dims
|
||||
- Custom models supported
|
||||
|
||||
## Comparison
|
||||
|
||||
| Method | Speed | Recall | Use Case |
|
||||
|--------|-------|--------|----------|
|
||||
| Exact | Fastest | 100% | Exact matches only |
|
||||
| Fuzzy | Fast | ~95% | Near-duplicates (recommended) |
|
||||
| Semantic | Slow | ~90% | Paraphrases, rewrites |
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with exact dedup** - Remove obvious duplicates
|
||||
2. **Use fuzzy for large datasets** - Best speed/quality trade-off
|
||||
3. **Semantic for high-value data** - Expensive but thorough
|
||||
4. **GPU acceleration required** - 10-16× speedup
|
||||
102
skills/mlops/nemo-curator/references/filtering.md
Normal file
102
skills/mlops/nemo-curator/references/filtering.md
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# Quality Filtering Guide
|
||||
|
||||
Complete guide to NeMo Curator's 30+ quality filters.
|
||||
|
||||
## Text-based filters
|
||||
|
||||
### Word count
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import WordCountFilter
|
||||
|
||||
# Filter by word count
|
||||
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
|
||||
```
|
||||
|
||||
### Repeated content
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import RepeatedLinesFilter
|
||||
|
||||
# Remove documents with >30% repeated lines
|
||||
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
|
||||
```
|
||||
|
||||
### Symbol ratio
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import SymbolToWordRatioFilter
|
||||
|
||||
# Remove documents with too many symbols
|
||||
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
|
||||
```
|
||||
|
||||
### URL ratio
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import UrlRatioFilter
|
||||
|
||||
# Remove documents with many URLs
|
||||
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
|
||||
```
|
||||
|
||||
## Language filtering
|
||||
|
||||
```python
|
||||
from nemo_curator.filters import LanguageIdentificationFilter
|
||||
|
||||
# Keep only English documents
|
||||
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"]))
|
||||
|
||||
# Multiple languages
|
||||
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"]))
|
||||
```
|
||||
|
||||
## Classifier-based filtering
|
||||
|
||||
### Quality classifier
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import QualityClassifier
|
||||
|
||||
quality_clf = QualityClassifier(
|
||||
model_path="nvidia/quality-classifier-deberta",
|
||||
batch_size=256,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Filter low-quality (threshold > 0.5 = high quality)
|
||||
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
|
||||
```
|
||||
|
||||
### NSFW classifier
|
||||
|
||||
```python
|
||||
from nemo_curator.classifiers import NSFWClassifier
|
||||
|
||||
nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda")
|
||||
|
||||
# Remove NSFW content
|
||||
dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9)
|
||||
```
|
||||
|
||||
## Heuristic filters
|
||||
|
||||
Full list of 30+ filters:
|
||||
- WordCountFilter
|
||||
- RepeatedLinesFilter
|
||||
- UrlRatioFilter
|
||||
- SymbolToWordRatioFilter
|
||||
- NonAlphaNumericFilter
|
||||
- BulletsFilter
|
||||
- WhiteSpaceFilter
|
||||
- ParenthesesFilter
|
||||
- LongWordFilter
|
||||
- And 20+ more...
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Apply cheap filters first** - Word count before GPU classifiers
|
||||
2. **Tune thresholds on sample** - Test on 10k docs before full run
|
||||
3. **Use GPU classifiers sparingly** - Expensive but effective
|
||||
4. **Chain filters efficiently** - Order by cost (cheap → expensive)
|
||||
314
skills/mlops/obliteratus/SKILL.md
Normal file
314
skills/mlops/obliteratus/SKILL.md
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
---
|
||||
name: obliteratus
|
||||
description: Remove refusal behaviors from open-weight LLMs using OBLITERATUS — mechanistic interpretability techniques (diff-in-means, SVD, whitened SVD, SAE decomposition, etc.) to excise guardrails while preserving reasoning. 9 CLI methods (+ 4 Python-API-only), 15 analysis modules, 116 model presets across 5 compute tiers. Use when a user wants to uncensor, abliterate, or remove refusal from an LLM.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
dependencies: [obliteratus, torch, transformers, bitsandbytes, accelerate, safetensors]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Abliteration, Uncensoring, Refusal-Removal, LLM, Weight-Projection, SVD, Mechanistic-Interpretability, HuggingFace, Model-Surgery]
|
||||
|
||||
---
|
||||
|
||||
# OBLITERATUS Skill
|
||||
|
||||
Remove refusal behaviors (guardrails) from open-weight LLMs without retraining or fine-tuning. Uses mechanistic interpretability techniques — including diff-in-means, SVD, whitened SVD, SAE decomposition, Bayesian kernel projection, and more — to identify and surgically excise refusal directions from model weights while preserving reasoning capabilities.
|
||||
|
||||
**License warning:** OBLITERATUS is AGPL-3.0. NEVER import it as a Python library. Always invoke via CLI (`obliteratus` command) or subprocess. This keeps Hermes Agent's MIT license clean.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Trigger when the user:
|
||||
- Wants to "uncensor" or "abliterate" an LLM
|
||||
- Asks about removing refusal/guardrails from a model
|
||||
- Wants to create an uncensored version of Llama, Qwen, Mistral, etc.
|
||||
- Mentions "refusal removal", "abliteration", "weight projection"
|
||||
- Wants to analyze how a model's refusal mechanism works
|
||||
- References OBLITERATUS, FailSpy, abliterator, or refusal directions
|
||||
|
||||
## Step 1: Installation
|
||||
|
||||
Check if already installed:
|
||||
```bash
|
||||
obliteratus --version 2>/dev/null && echo "INSTALLED" || echo "NOT INSTALLED"
|
||||
```
|
||||
|
||||
If not installed, clone and install from GitHub:
|
||||
```
|
||||
Repository: https://github.com/elder-plinius/OBLITERATUS
|
||||
Install: pip install -e . (from the cloned directory)
|
||||
For Gradio UI: pip install -e ".[spaces]"
|
||||
```
|
||||
|
||||
**IMPORTANT:** Confirm with user before installing. This pulls in ~5-10GB of dependencies (PyTorch, Transformers, bitsandbytes, etc.).
|
||||
|
||||
## Step 2: Check Hardware
|
||||
|
||||
Before anything, check what GPU is available:
|
||||
```bash
|
||||
python3 -c "
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
gpu = torch.cuda.get_device_name(0)
|
||||
vram = torch.cuda.get_device_properties(0).total_mem / 1024**3
|
||||
print(f'GPU: {gpu}')
|
||||
print(f'VRAM: {vram:.1f} GB')
|
||||
if vram < 4: print('TIER: tiny (models under 1B)')
|
||||
elif vram < 8: print('TIER: small (models 1-4B)')
|
||||
elif vram < 16: print('TIER: medium (models 4-9B with 4bit quant)')
|
||||
elif vram < 32: print('TIER: large (models 8-32B with 4bit quant)')
|
||||
else: print('TIER: frontier (models 32B+)')
|
||||
else:
|
||||
print('NO GPU - only tiny models (under 1B) on CPU')
|
||||
"
|
||||
```
|
||||
|
||||
### VRAM Requirements (with 4-bit quantization)
|
||||
|
||||
| VRAM | Max Model Size | Example Models |
|
||||
|:---------|:----------------|:--------------------------------------------|
|
||||
| CPU only | ~1B params | GPT-2, TinyLlama, SmolLM |
|
||||
| 4-8 GB | ~4B params | Qwen2.5-1.5B, Phi-3.5 mini, Llama 3.2 3B |
|
||||
| 8-16 GB | ~9B params | Llama 3.1 8B, Mistral 7B, Gemma 2 9B |
|
||||
| 24 GB | ~32B params | Qwen3-32B, Llama 3.1 70B (tight), Command-R |
|
||||
| 48 GB+ | ~72B+ params | Qwen2.5-72B, DeepSeek-R1 |
|
||||
| Multi-GPU| 200B+ params | Llama 3.1 405B, DeepSeek-V3 (685B MoE) |
|
||||
|
||||
## Step 3: Browse Available Models
|
||||
|
||||
```bash
|
||||
# List models for your compute tier
|
||||
obliteratus models --tier medium
|
||||
|
||||
# Get architecture info for a specific model
|
||||
obliteratus info meta-llama/Llama-3.1-8B-Instruct
|
||||
```
|
||||
|
||||
## Step 4: Choose a Method
|
||||
|
||||
### Method Selection Guide
|
||||
|
||||
**First time / unsure? Use `informed`.** It auto-configures everything.
|
||||
|
||||
| Situation | Recommended Method | Why |
|
||||
|:----------------------------------|:-------------------|:-----------------------------------------|
|
||||
| First attempt, any model | `informed` | Auto-detects alignment type, auto-tunes |
|
||||
| Quick test / prototyping | `basic` | Fast, simple, good enough to evaluate |
|
||||
| Dense model (Llama, Mistral) | `advanced` | Multi-direction, norm-preserving |
|
||||
| MoE model (DeepSeek, Mixtral) | `nuclear` | Expert-granular, handles MoE complexity |
|
||||
| Reasoning model (R1 distills) | `surgical` | CoT-aware, preserves chain-of-thought |
|
||||
| Stubborn refusals persist | `aggressive` | Whitened SVD + head surgery + jailbreak |
|
||||
| Want reversible changes | Use steering vectors (see Analysis section) |
|
||||
| Maximum quality, time no object | `optimized` | Bayesian search for best parameters |
|
||||
|
||||
### 9 CLI Methods
|
||||
|
||||
These can be passed to `--method` on the command line:
|
||||
|
||||
- **basic** — Single refusal direction via diff-in-means. Fastest, simplest. (Arditi et al. 2024)
|
||||
- **advanced** — Multiple SVD directions, norm-preserving projection. Good default.
|
||||
- **aggressive** — Whitened SVD + jailbreak contrast + attention head surgery
|
||||
- **spectral_cascade** — DCT frequency-domain decomposition
|
||||
- **informed** — Runs analysis DURING abliteration to auto-configure. Detects DPO/RLHF/CAI, maps refusal geometry, compensates for self-repair. Best quality.
|
||||
- **surgical** — SAE features + neuron masking + head surgery + per-expert. Maximum precision.
|
||||
- **optimized** — Bayesian hyperparameter search (Optuna TPE). Slowest but optimal.
|
||||
- **inverted** — Flips the refusal direction (model becomes eager to help, not just neutral)
|
||||
- **nuclear** — Maximum force combo for stubborn MoE models.
|
||||
|
||||
### 4 Python-API-Only Methods
|
||||
|
||||
These reproduce prior community/academic work but are NOT available via CLI — only via the Python API (`from obliteratus.abliterate import AbliterationPipeline`). **Do not use these in CLI commands.**
|
||||
|
||||
- **failspy** — FailSpy/abliterator reproduction
|
||||
- **gabliteration** — Gabliteration reproduction
|
||||
- **heretic** — Heretic/p-e-w reproduction
|
||||
- **rdo** — Refusal Direction Optimization (ICML 2025)
|
||||
|
||||
## Step 5: Run Abliteration
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
# Default (advanced method)
|
||||
obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct
|
||||
|
||||
# With the informed pipeline (recommended)
|
||||
obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct --method informed
|
||||
|
||||
# With 4-bit quantization to save VRAM
|
||||
obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct \
|
||||
--method informed \
|
||||
--quantization 4bit \
|
||||
--output-dir ./abliterated-models
|
||||
|
||||
# For large models (120B+), use conservative settings
|
||||
obliteratus obliterate Qwen/Qwen2.5-72B-Instruct \
|
||||
--method advanced \
|
||||
--quantization 4bit \
|
||||
--large-model \
|
||||
--output-dir ./abliterated-models
|
||||
```
|
||||
|
||||
### Fine-Tuning Parameters
|
||||
|
||||
```bash
|
||||
obliteratus obliterate <model> \
|
||||
--method advanced \
|
||||
--n-directions 8 \
|
||||
--regularization 0.1 \
|
||||
--refinement-passes 3 \
|
||||
--dtype bfloat16 \
|
||||
--device auto \
|
||||
--output-dir ./output
|
||||
```
|
||||
|
||||
Parameter explanations:
|
||||
- `--n-directions N` — How many refusal directions to remove (default: auto-detected)
|
||||
- `--regularization 0.0-1.0` — Fraction of original weights to preserve (higher = safer but less complete removal)
|
||||
- `--refinement-passes N` — Iterative passes to catch self-repair (Ouroboros effect)
|
||||
- `--dtype` — float16, bfloat16, or float32
|
||||
- `--quantization` — 4bit or 8bit (saves VRAM, slight quality tradeoff)
|
||||
- `--large-model` — Conservative defaults for 120B+ models (fewer directions, fewer passes)
|
||||
|
||||
### Interactive Mode (Guided)
|
||||
|
||||
For users unsure about options:
|
||||
```bash
|
||||
obliteratus interactive
|
||||
```
|
||||
|
||||
### Web UI (Gradio)
|
||||
|
||||
```bash
|
||||
obliteratus ui --port 7860
|
||||
```
|
||||
|
||||
## Step 6: Verify Results
|
||||
|
||||
After abliteration, check the output report for:
|
||||
|
||||
| Metric | Good Value | Concerning Value | Meaning |
|
||||
|:---------------|:--------------------|:------------------------|:-------------------------------------------|
|
||||
| Refusal rate | Near 0% | > 10% | Refusals still present, try harder method |
|
||||
| Perplexity | Within 10% of orig | > 20% increase | Model coherence damaged, too aggressive |
|
||||
| KL divergence | < 0.1 | > 0.5 | Large output distribution shift |
|
||||
| Coherence | High | Low | Model generating nonsense |
|
||||
|
||||
### If perplexity spiked (too aggressive):
|
||||
1. Increase `--regularization` (e.g., 0.2 or 0.3)
|
||||
2. Decrease `--n-directions` (e.g., 4 instead of 8)
|
||||
3. Use a less aggressive method (`advanced` instead of `aggressive`)
|
||||
|
||||
### If refusal persists (not aggressive enough):
|
||||
1. Use `--method aggressive` or `--method nuclear`
|
||||
2. Add `--refinement-passes 3` to catch self-repair
|
||||
3. Use `--method informed` which auto-compensates
|
||||
|
||||
## Step 7: Use the Abliterated Model
|
||||
|
||||
The output is a standard HuggingFace model directory. Use it like any other model:
|
||||
|
||||
### Quick test
|
||||
```bash
|
||||
python3 << 'EOF'
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained("./abliterated-models/model-name")
|
||||
tokenizer = AutoTokenizer.from_pretrained("./abliterated-models/model-name")
|
||||
inputs = tokenizer("Write a story about:", return_tensors="pt").to(model.device)
|
||||
outputs = model.generate(**inputs, max_new_tokens=200)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
EOF
|
||||
```
|
||||
|
||||
### Upload to HuggingFace Hub
|
||||
```bash
|
||||
huggingface-cli login # if not already logged in
|
||||
huggingface-cli upload your-username/model-name-abliterated ./abliterated-models/model-name
|
||||
```
|
||||
|
||||
### Serve with vLLM
|
||||
```bash
|
||||
vllm serve ./abliterated-models/model-name --port 8000
|
||||
```
|
||||
|
||||
## Analysis Modules (15 Modules, Pre-Abliteration, Optional)
|
||||
|
||||
For understanding refusal geometry before committing to abliteration.
|
||||
|
||||
### Run a Study
|
||||
|
||||
```bash
|
||||
obliteratus run study-config.yaml --preset jailbreak
|
||||
```
|
||||
|
||||
### Study Presets
|
||||
|
||||
| Preset | Purpose | Time |
|
||||
|:-------------|:-------------------------------------|:-------|
|
||||
| `quick` | Sanity check, basic metrics | ~5 min |
|
||||
| `jailbreak` | Refusal circuit localization | ~20 min|
|
||||
| `guardrail` | Guardrail robustness evaluation | ~30 min|
|
||||
| `attention` | Attention head contributions | ~30 min|
|
||||
| `knowledge` | FFN importance mapping | ~30 min|
|
||||
| `full` | Complete analysis, all strategies | ~1 hr |
|
||||
|
||||
### Key Analysis Modules
|
||||
|
||||
- **Alignment Imprint Detection** — Fingerprints DPO vs RLHF vs CAI vs SFT from subspace geometry
|
||||
- **Concept Cone Geometry** — Is refusal one linear direction or a polyhedral cone (many directions)?
|
||||
- **Refusal Logit Lens** — Which transformer layer makes the refusal decision?
|
||||
- **Ouroboros Detection** — Will the model self-repair its refusal after removal?
|
||||
- **Causal Tracing** — Which attention heads and MLP layers are causally necessary for refusal?
|
||||
- **Cross-Model Transfer** — Can refusal directions from one model architecture work on another?
|
||||
- **Residual Stream Decomposition** — Attention vs MLP contribution to refusal behavior
|
||||
- **SAE-based Analysis** — Sparse Autoencoder feature decomposition of refusal circuits
|
||||
|
||||
## Steering Vectors (Reversible Alternative)
|
||||
|
||||
For testing refusal removal without permanent weight changes:
|
||||
|
||||
Steering vectors apply activation hooks at inference time. Model weights stay unchanged.
|
||||
Generated during the PROBE/DISTILL stages and can be saved/applied/removed at will.
|
||||
Useful for A/B testing before committing to permanent abliteration.
|
||||
|
||||
## YAML Config for Reproducible Studies
|
||||
|
||||
For complex or reproducible workflows, use YAML configs. See templates/ for examples:
|
||||
```bash
|
||||
obliteratus run my_study.yaml
|
||||
```
|
||||
|
||||
## Telemetry Notice
|
||||
|
||||
- **CLI usage (local installs)**: Telemetry is OFF by default. Must explicitly opt in via `OBLITERATUS_TELEMETRY=1` env var or `--contribute` flag.
|
||||
- **HuggingFace Spaces**: Telemetry is ON by default (auto-enabled when `SPACE_ID` env var is detected).
|
||||
- Collected: model ID, method, benchmark scores, hardware info, timing (anonymous)
|
||||
- NOT collected: IP addresses, user identity, prompt content
|
||||
- Force off: `export OBLITERATUS_TELEMETRY=0`
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
1. **OOM (Out of Memory)** — Use `--quantization 4bit` and `--large-model` for big models
|
||||
2. **Perplexity spike** — Too aggressive. Increase `--regularization` or reduce `--n-directions`
|
||||
3. **Refusal persists** — Try `--method aggressive` or `--refinement-passes 3`
|
||||
4. **MoE models resist** — Use `--method nuclear` for DeepSeek, Mixtral, DBRX
|
||||
5. **Gated models fail** — Run `huggingface-cli login` and accept model terms on HF website first
|
||||
6. **Self-repair (Ouroboros)** — Some models reconstruct refusal. Use `--method informed` which auto-compensates
|
||||
7. **CoT damage** — Reasoning models lose chain-of-thought. Use `--method surgical` (CoT-aware)
|
||||
8. **Disk space** — Output is full model copy. 8B fp16 = ~16GB, 70B fp16 = ~140GB
|
||||
9. **Slow on CPU** — CPU-only is viable only for tiny models (<1B). Anything bigger needs GPU.
|
||||
|
||||
## Complementary Hermes Skills
|
||||
|
||||
After abliteration:
|
||||
- **axolotl** / **unsloth** — Fine-tune the abliterated model further
|
||||
- **serving-llms-vllm** — Serve the model as an OpenAI-compatible API
|
||||
- **sparse-autoencoder-training** — Train SAEs for deeper interpretability work
|
||||
|
||||
## Resources
|
||||
|
||||
- [OBLITERATUS GitHub](https://github.com/elder-plinius/OBLITERATUS) (AGPL-3.0)
|
||||
- [HuggingFace Spaces Demo](https://huggingface.co/spaces/pliny-the-prompter/obliteratus)
|
||||
- [Arditi et al. 2024 — Refusal in LMs Is Mediated by a Single Direction](https://arxiv.org/abs/2406.11717)
|
||||
- [Refusal Direction Optimization — ICML 2025](https://arxiv.org/abs/2411.14793)
|
||||
170
skills/mlops/obliteratus/references/analysis-modules.md
Normal file
170
skills/mlops/obliteratus/references/analysis-modules.md
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
# OBLITERATUS Analysis Modules — Reference
|
||||
|
||||
15 analysis modules for mechanistic interpretability of refusal in LLMs.
|
||||
These help you understand HOW a model refuses before you decide to remove it.
|
||||
|
||||
> **Note:** The `analysis/` directory contains additional utility files (utils.py,
|
||||
> visualization.py, etc.) and helper functions beyond the 15 core analysis modules
|
||||
> listed below. The module count matches the README's "15 deep analysis modules."
|
||||
|
||||
## Core Analysis (Run These First)
|
||||
|
||||
### Alignment Imprint Detection
|
||||
**File:** `alignment_imprint.py`
|
||||
**Purpose:** Identifies what alignment technique was used to train the model
|
||||
**Detects:** DPO, RLHF, CAI (Constitutional AI), SFT (Supervised Fine-Tuning)
|
||||
**How:** Analyzes subspace geometry — each alignment method leaves a distinct
|
||||
geometric "fingerprint" in the weight space
|
||||
**Output:** Detected method + confidence score
|
||||
**Why it matters:** Different alignment methods need different abliteration approaches.
|
||||
DPO models typically have cleaner single-direction refusal; RLHF is more diffuse.
|
||||
|
||||
### Concept Cone Geometry
|
||||
**File:** `concept_geometry.py`
|
||||
**Purpose:** Maps whether refusal is one direction or a polyhedral cone (many)
|
||||
**Output:** Cone angle, dimensionality, per-category breakdown
|
||||
**Why it matters:** If refusal is a single direction, `basic` method works. If it's
|
||||
a cone (multiple directions for different refusal categories), you need `advanced`
|
||||
or `informed` with higher `n_directions`.
|
||||
|
||||
### Refusal Logit Lens
|
||||
**File:** `logit_lens.py`
|
||||
**Purpose:** Identifies the specific layer where the model "decides" to refuse
|
||||
**How:** Projects intermediate hidden states to vocabulary space at each layer,
|
||||
watches when "I cannot" tokens spike in probability
|
||||
**Output:** Layer-by-layer refusal probability plot
|
||||
**Why it matters:** Tells you which layers are most important to target
|
||||
|
||||
### Ouroboros (Self-Repair) Detection
|
||||
**File:** `anti_ouroboros.py`
|
||||
**Purpose:** Predicts whether the model will reconstruct its refusal after removal
|
||||
**How:** Measures redundancy in refusal representation across layers
|
||||
**Output:** Self-repair risk score (0-1)
|
||||
**Why it matters:** High self-repair risk means you need multiple refinement passes
|
||||
or the `informed` method which auto-compensates
|
||||
|
||||
### Causal Tracing
|
||||
**File:** `causal_tracing.py`
|
||||
**Purpose:** Determines which components are causally necessary for refusal
|
||||
**How:** Patches activations between clean and corrupted runs, measures causal effect
|
||||
**Output:** Causal importance map across layers, heads, and MLPs
|
||||
**Why it matters:** Shows exactly which components to target for surgical removal
|
||||
|
||||
## Geometric Analysis
|
||||
|
||||
### Cross-Layer Alignment
|
||||
**File:** `cross_layer.py`
|
||||
**Purpose:** Measures how aligned refusal directions are across layers
|
||||
**Output:** Alignment matrix, cluster assignments
|
||||
**Why it matters:** If directions are highly aligned across layers, removal is easier.
|
||||
If they cluster, you may need layer-group-specific directions.
|
||||
|
||||
### Residual Stream Decomposition
|
||||
**File:** `residual_stream.py`
|
||||
**Purpose:** Breaks down refusal into Attention vs MLP contributions
|
||||
**Output:** Per-layer Attention/MLP contribution to refusal direction
|
||||
**Why it matters:** Helps decide whether to target attention heads, MLPs, or both
|
||||
|
||||
### Riemannian Manifold Geometry
|
||||
**File:** `riemannian_manifold.py` (673 lines)
|
||||
**Purpose:** Analyzes the weight manifold geometry around refusal directions
|
||||
**Output:** Curvature, geodesics, tangent space analysis
|
||||
**Why it matters:** Research-grade; helps understand the geometric structure of alignment
|
||||
|
||||
### Whitened SVD
|
||||
**File:** `whitened_svd.py`
|
||||
**Purpose:** Covariance-normalized SVD extraction
|
||||
**How:** Whitens the activation covariance before computing refusal directions,
|
||||
separating true refusal signal from natural activation variance
|
||||
**Output:** Cleaner refusal directions with less noise
|
||||
**Why it matters:** Produces more precise directions, especially for noisy activations
|
||||
|
||||
## Probing & Classification
|
||||
|
||||
### Activation Probing
|
||||
**File:** `activation_probing.py`
|
||||
**Purpose:** Post-excision probing to verify refusal signal is truly gone
|
||||
**Output:** Residual refusal signal strength per layer
|
||||
**Why it matters:** Verification that abliteration was complete
|
||||
|
||||
### Probing Classifiers
|
||||
**File:** `probing_classifiers.py`
|
||||
**Purpose:** Trains linear classifiers to detect refusal in hidden states
|
||||
**Output:** Classification accuracy per layer (should drop to ~50% after abliteration)
|
||||
**Why it matters:** Quantitative measure of refusal removal completeness
|
||||
|
||||
### Activation Patching
|
||||
**File:** `activation_patching.py`
|
||||
**Purpose:** Interchange interventions — swap activations between harmful/harmless runs
|
||||
**Output:** Which components are sufficient (not just necessary) for refusal
|
||||
**Why it matters:** Complementary to causal tracing; together they give full picture
|
||||
|
||||
## Transfer & Robustness
|
||||
|
||||
### Cross-Model Transfer
|
||||
**File:** `cross_model_transfer.py`
|
||||
**Purpose:** Tests if refusal directions from one model work on another
|
||||
**Output:** Transfer success rate between model pairs
|
||||
**Why it matters:** If directions transfer, you can skip PROBE stage on similar models
|
||||
|
||||
### Defense Robustness
|
||||
**File:** `defense_robustness.py`
|
||||
**Purpose:** Evaluates how robust the model's refusal defenses are
|
||||
**Output:** Robustness score, entanglement mapping
|
||||
**Why it matters:** Higher robustness = need more aggressive method
|
||||
|
||||
### Spectral Certification
|
||||
**File:** `spectral_certification.py`
|
||||
**Purpose:** Certifies completeness of refusal direction removal
|
||||
**Output:** Spectral gap analysis, completeness score
|
||||
**Why it matters:** Formal verification that all major refusal components are addressed
|
||||
|
||||
## Advanced / Research
|
||||
|
||||
### SAE-based Abliteration
|
||||
**File:** `sae_abliteration.py` (762 lines)
|
||||
**Purpose:** Uses Sparse Autoencoder features to decompose refusal at feature level
|
||||
**Output:** Refusal-specific SAE features, targeted removal
|
||||
**Why it matters:** Most fine-grained approach; can target individual refusal "concepts"
|
||||
|
||||
### Wasserstein Optimal Extraction
|
||||
**File:** `wasserstein_optimal.py`
|
||||
**Purpose:** Optimal transport-based direction extraction
|
||||
**Output:** Wasserstein-optimal refusal directions
|
||||
**Why it matters:** Theoretically optimal direction extraction under distributional assumptions
|
||||
|
||||
### Bayesian Kernel Projection
|
||||
**File:** `bayesian_kernel_projection.py`
|
||||
**Purpose:** Bayesian approach to refusal direction projection
|
||||
**Output:** Posterior distribution over refusal directions
|
||||
**Why it matters:** Quantifies uncertainty in direction estimation
|
||||
|
||||
### Conditional Abliteration
|
||||
**File:** `conditional_abliteration.py`
|
||||
**Purpose:** Domain-specific conditional removal (remove refusal for topic X but keep for Y)
|
||||
**Output:** Per-domain refusal directions
|
||||
**Why it matters:** Selective uncensoring — remove only specific refusal categories
|
||||
|
||||
### Steering Vectors
|
||||
**File:** `steering_vectors.py`
|
||||
**Purpose:** Generate inference-time steering vectors (reversible alternative)
|
||||
**Output:** Steering vector files that can be applied/removed at inference
|
||||
**Why it matters:** Non-destructive alternative to permanent weight modification
|
||||
|
||||
### Tuned Lens
|
||||
**File:** `tuned_lens.py`
|
||||
**Purpose:** Trained linear probes per layer (more accurate than raw logit lens)
|
||||
**Output:** Layer-by-layer refusal representation with trained projections
|
||||
**Why it matters:** More accurate than logit lens, especially for deeper models
|
||||
|
||||
### Multi-Token Position Analysis
|
||||
**File:** `multi_token_position.py`
|
||||
**Purpose:** Analyzes refusal signal at multiple token positions (not just last)
|
||||
**Output:** Position-dependent refusal direction maps
|
||||
**Why it matters:** Some models encode refusal at the system prompt position, not the query
|
||||
|
||||
### Sparse Surgery
|
||||
**File:** `sparse_surgery.py`
|
||||
**Purpose:** Row-level sparse weight surgery instead of full matrix projection
|
||||
**Output:** Targeted weight modifications at the row level
|
||||
**Why it matters:** More surgical than full-matrix projection, less collateral damage
|
||||
132
skills/mlops/obliteratus/references/methods-guide.md
Normal file
132
skills/mlops/obliteratus/references/methods-guide.md
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
# OBLITERATUS Methods — Detailed Guide
|
||||
|
||||
> **Important:** The CLI (`obliteratus obliterate --method`) accepts 9 methods:
|
||||
> basic, advanced, aggressive, spectral_cascade, informed, surgical, optimized,
|
||||
> inverted, nuclear. Four additional methods (failspy, gabliteration, heretic, rdo)
|
||||
> are available only via the Python API and will be rejected by argparse if used on CLI.
|
||||
|
||||
## How Abliteration Works (Theory)
|
||||
|
||||
When a model is trained with RLHF/DPO/CAI, it learns to represent "should I refuse?"
|
||||
as a direction in its internal activation space. When processing a "harmful" prompt,
|
||||
activations shift in this direction, causing the model to generate refusal text.
|
||||
|
||||
Abliteration works by:
|
||||
1. Measuring this direction (the difference between harmful and harmless activations)
|
||||
2. Removing it from the model's weight matrices via orthogonal projection
|
||||
3. The model can no longer "point toward" refusal, so it responds normally
|
||||
|
||||
Mathematically: `W_new = W_old - (W_old @ d @ d.T)` where `d` is the refusal direction.
|
||||
|
||||
## Method Details
|
||||
|
||||
### basic
|
||||
**Technique:** Single refusal direction via diff-in-means
|
||||
**Based on:** Arditi et al. 2024 ("Refusal in Language Models Is Mediated by a Single Direction")
|
||||
**Speed:** Fast (~5-10 min for 8B)
|
||||
**Quality:** Moderate — works for simple refusal patterns
|
||||
**Best for:** Quick tests, models with clean single-direction refusal
|
||||
**Limitation:** Misses complex multi-direction refusal patterns
|
||||
|
||||
### advanced (DEFAULT)
|
||||
**Technique:** Multiple SVD directions with norm-preserving projection
|
||||
**Speed:** Medium (~10-20 min for 8B)
|
||||
**Quality:** Good — handles multi-direction refusal
|
||||
**Best for:** Dense models (Llama, Qwen, Mistral) as a reliable default
|
||||
**Key improvement:** Norm preservation prevents weight magnitude drift
|
||||
|
||||
### informed (RECOMMENDED)
|
||||
**Technique:** Analysis-guided auto-configuration
|
||||
**Speed:** Slow (~20-40 min for 8B, runs 4 analysis modules first)
|
||||
**Quality:** Best — adapts to each model's specific refusal implementation
|
||||
**Best for:** Any model when quality matters more than speed
|
||||
|
||||
The informed pipeline runs these analysis modules during abliteration:
|
||||
1. **AlignmentImprintDetector** — Detects DPO/RLHF/CAI/SFT → sets regularization
|
||||
2. **ConceptConeAnalyzer** — Polyhedral vs linear refusal → sets n_directions
|
||||
3. **CrossLayerAlignmentAnalyzer** — Cluster-aware → selects target layers
|
||||
4. **DefenseRobustnessEvaluator** — Self-repair risk → sets refinement passes
|
||||
5. **Ouroboros loop** — Re-probes after excision, re-excises if refusal persists
|
||||
|
||||
### aggressive
|
||||
**Technique:** Whitened SVD + jailbreak-contrastive activations + attention head surgery
|
||||
**Speed:** Slow (~30-60 min for 8B)
|
||||
**Quality:** High but higher risk of coherence damage
|
||||
**Best for:** Models that resist gentler methods
|
||||
**Key feature:** Whitened SVD separates refusal signal from natural activation variance
|
||||
|
||||
### surgical
|
||||
**Technique:** SAE features + neuron masking + head surgery + per-expert directions
|
||||
**Speed:** Very slow (~1-2 hrs for 8B, needs SAE)
|
||||
**Quality:** Highest precision
|
||||
**Best for:** Reasoning models (R1 distills) where you must preserve CoT
|
||||
**Key feature:** CoT-Aware — explicitly protects reasoning-critical directions
|
||||
|
||||
### nuclear
|
||||
**Technique:** Everything combined — expert transplant + steering + per-expert directions
|
||||
**Speed:** Very slow
|
||||
**Quality:** Most thorough removal, highest risk of side effects
|
||||
**Best for:** Stubborn MoE models (DeepSeek, Mixtral, DBRX) that resist other methods
|
||||
**Key feature:** Expert-granular abliteration decomposes signals per MoE expert
|
||||
|
||||
### optimized
|
||||
**Technique:** Bayesian hyperparameter search via Optuna TPE
|
||||
**Speed:** Very slow (runs many trials)
|
||||
**Quality:** Finds optimal configuration automatically
|
||||
**Best for:** Research, when you want the mathematically best parameters
|
||||
**Requires:** optuna package
|
||||
|
||||
### spectral_cascade
|
||||
**Technique:** DCT frequency-domain decomposition of refusal signal
|
||||
**Speed:** Medium-slow
|
||||
**Quality:** Novel approach, less battle-tested
|
||||
**Best for:** Research, exploring alternative decomposition strategies
|
||||
|
||||
### inverted
|
||||
**Technique:** Reflects (inverts) the refusal direction instead of removing it
|
||||
**Speed:** Fast (same as basic)
|
||||
**Quality:** Aggressive — model becomes actively willing, not just neutral
|
||||
**Best for:** When you want the model to be maximally helpful
|
||||
**Warning:** Can make the model too eager; may reduce safety-adjacent reasoning
|
||||
|
||||
### failspy / gabliteration / heretic / rdo (PYTHON API ONLY)
|
||||
**Technique:** Faithful reproductions of prior community/academic work
|
||||
**Speed:** Varies
|
||||
**Quality:** Known baselines
|
||||
**Best for:** Reproducing published results, comparing methods
|
||||
**⚠️ NOT available via CLI** — these methods are only accessible via the Python API.
|
||||
Do not use `--method failspy` etc. in CLI commands; argparse will reject them.
|
||||
|
||||
## Method Selection Flowchart
|
||||
|
||||
```
|
||||
Is this a quick test?
|
||||
├─ YES → basic
|
||||
└─ NO → Is the model MoE (DeepSeek, Mixtral)?
|
||||
├─ YES → nuclear
|
||||
└─ NO → Is it a reasoning model (R1 distill)?
|
||||
├─ YES → surgical
|
||||
└─ NO → Do you care about speed?
|
||||
├─ YES → advanced
|
||||
└─ NO → informed
|
||||
```
|
||||
|
||||
## Key Parameters
|
||||
|
||||
| Parameter | Range | Default | Effect |
|
||||
|:--------------------|:---------|:--------|:--------------------------------------------|
|
||||
| n_directions | 1-32 | auto | More = more thorough but riskier |
|
||||
| regularization | 0.0-1.0 | 0.0 | Higher preserves more original behavior |
|
||||
| refinement_passes | 1-5 | 1 | More catches self-repair (Ouroboros effect) |
|
||||
| quantization | 4/8 bit | none | Saves VRAM, slight quality tradeoff |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Solution |
|
||||
|:---------------------------|:--------------------------------------------------|
|
||||
| Refusal rate still > 10% | Try aggressive/nuclear, add refinement passes |
|
||||
| Perplexity up > 20% | Reduce n_directions, increase regularization |
|
||||
| Model generates nonsense | Regularization too low, try 0.2-0.3 |
|
||||
| OOM on GPU | Use 4-bit quantization, or try smaller model |
|
||||
| MoE model barely changes | Use nuclear method (expert-granular) |
|
||||
| CoT reasoning broken | Use surgical method (CoT-aware) |
|
||||
33
skills/mlops/obliteratus/templates/abliteration-config.yaml
Normal file
33
skills/mlops/obliteratus/templates/abliteration-config.yaml
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# OBLITERATUS Abliteration Config
|
||||
# Usage: obliteratus run this-file.yaml
|
||||
#
|
||||
# This is for reproducible, version-controlled abliteration runs.
|
||||
# For one-off usage, the CLI flags are simpler.
|
||||
|
||||
# Model to abliterate
|
||||
model:
|
||||
name: "meta-llama/Llama-3.1-8B-Instruct"
|
||||
dtype: "bfloat16" # float16, bfloat16, float32
|
||||
quantization: null # null, "4bit", "8bit"
|
||||
device: "auto" # auto, cuda, cuda:0, cpu
|
||||
|
||||
# Abliteration method and parameters
|
||||
abliteration:
|
||||
method: "informed" # See SKILL.md Step 4 for all 13 methods
|
||||
n_directions: null # null = auto-detect, or integer (e.g., 8)
|
||||
regularization: 0.0 # 0.0-1.0, fraction of original to preserve
|
||||
refinement_passes: 1 # Iterative passes (increase for self-repair)
|
||||
norm_preserve: true # Keep weight norms intact after projection
|
||||
|
||||
# Output
|
||||
output:
|
||||
directory: "./abliterated-models"
|
||||
save_metadata: true # Save abliteration_metadata.json alongside model
|
||||
contribute: false # Save community contribution data
|
||||
|
||||
# Verification
|
||||
verify:
|
||||
enabled: true
|
||||
test_prompts: null # null = use built-in test prompts
|
||||
compute_perplexity: true
|
||||
compute_kl: true
|
||||
40
skills/mlops/obliteratus/templates/analysis-study.yaml
Normal file
40
skills/mlops/obliteratus/templates/analysis-study.yaml
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
# OBLITERATUS Analysis Study Config
|
||||
# Usage: obliteratus run this-file.yaml --preset jailbreak
|
||||
#
|
||||
# Run analysis modules to understand refusal geometry BEFORE abliterating.
|
||||
# Useful for research or when you want to understand what you're removing.
|
||||
|
||||
# Model to analyze
|
||||
model:
|
||||
name: "meta-llama/Llama-3.1-8B-Instruct"
|
||||
dtype: "bfloat16"
|
||||
quantization: "4bit" # Saves VRAM for analysis
|
||||
device: "auto"
|
||||
|
||||
# Study configuration
|
||||
study:
|
||||
# Available presets: quick, full, attention, jailbreak, guardrail, knowledge
|
||||
preset: "jailbreak"
|
||||
|
||||
# Or specify individual strategies:
|
||||
# strategies:
|
||||
# - layer_removal
|
||||
# - head_pruning
|
||||
# - ffn_ablation
|
||||
# - embedding_ablation
|
||||
|
||||
# Analysis modules to run (subset of the 27 available)
|
||||
analysis:
|
||||
- alignment_imprint # Detect DPO/RLHF/CAI/SFT training method
|
||||
- concept_geometry # Map refusal cone geometry
|
||||
- logit_lens # Find which layer decides to refuse
|
||||
- anti_ouroboros # Detect self-repair tendency
|
||||
- cross_layer # Cross-layer alignment clustering
|
||||
- causal_tracing # Causal necessity of components
|
||||
- residual_stream # Attention vs MLP contribution
|
||||
|
||||
# Output
|
||||
output:
|
||||
directory: "./analysis-results"
|
||||
save_plots: true # Generate matplotlib visualizations
|
||||
save_report: true # Generate markdown report
|
||||
41
skills/mlops/obliteratus/templates/batch-abliteration.yaml
Normal file
41
skills/mlops/obliteratus/templates/batch-abliteration.yaml
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
# OBLITERATUS Batch Abliteration Config
|
||||
# Abliterate multiple models with the same method for comparison.
|
||||
#
|
||||
# Run each one sequentially:
|
||||
# for model in models; do obliteratus obliterate $model --method informed; done
|
||||
#
|
||||
# Or use this as a reference for which models to process.
|
||||
|
||||
# Common settings
|
||||
defaults:
|
||||
method: "informed"
|
||||
quantization: "4bit"
|
||||
output_dir: "./abliterated-models"
|
||||
|
||||
# Models to process (grouped by compute tier)
|
||||
models:
|
||||
# Small (4-8 GB VRAM)
|
||||
small:
|
||||
- "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
- "microsoft/Phi-3.5-mini-instruct"
|
||||
- "meta-llama/Llama-3.2-3B-Instruct"
|
||||
|
||||
# Medium (8-16 GB VRAM)
|
||||
medium:
|
||||
- "meta-llama/Llama-3.1-8B-Instruct"
|
||||
- "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
- "google/gemma-2-9b-it"
|
||||
- "Qwen/Qwen2.5-7B-Instruct"
|
||||
|
||||
# Large (24 GB VRAM, 4-bit quantization)
|
||||
large:
|
||||
- "Qwen/Qwen2.5-14B-Instruct"
|
||||
- "Qwen/Qwen3-32B"
|
||||
- "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
||||
|
||||
# Per-model method overrides (optional)
|
||||
overrides:
|
||||
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B":
|
||||
method: "surgical" # CoT-aware for reasoning models
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1":
|
||||
method: "nuclear" # Expert-granular for MoE models
|
||||
434
skills/mlops/peft/SKILL.md
Normal file
434
skills/mlops/peft/SKILL.md
Normal file
|
|
@ -0,0 +1,434 @@
|
|||
---
|
||||
name: peft-fine-tuning
|
||||
description: Parameter-efficient fine-tuning for LLMs using LoRA, QLoRA, and 25+ methods. Use when fine-tuning large models (7B-70B) with limited GPU memory, when you need to train <1% of parameters with minimal accuracy loss, or for multi-adapter serving. HuggingFace's official library integrated with transformers ecosystem.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [peft>=0.13.0, transformers>=4.45.0, torch>=2.0.0, bitsandbytes>=0.43.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Fine-Tuning, PEFT, LoRA, QLoRA, Parameter-Efficient, Adapters, Low-Rank, Memory Optimization, Multi-Adapter]
|
||||
|
||||
---
|
||||
|
||||
# PEFT (Parameter-Efficient Fine-Tuning)
|
||||
|
||||
Fine-tune LLMs by training <1% of parameters using LoRA, QLoRA, and 25+ adapter methods.
|
||||
|
||||
## When to use PEFT
|
||||
|
||||
**Use PEFT/LoRA when:**
|
||||
- Fine-tuning 7B-70B models on consumer GPUs (RTX 4090, A100)
|
||||
- Need to train <1% parameters (6MB adapters vs 14GB full model)
|
||||
- Want fast iteration with multiple task-specific adapters
|
||||
- Deploying multiple fine-tuned variants from one base model
|
||||
|
||||
**Use QLoRA (PEFT + quantization) when:**
|
||||
- Fine-tuning 70B models on single 24GB GPU
|
||||
- Memory is the primary constraint
|
||||
- Can accept ~5% quality trade-off vs full fine-tuning
|
||||
|
||||
**Use full fine-tuning instead when:**
|
||||
- Training small models (<1B parameters)
|
||||
- Need maximum quality and have compute budget
|
||||
- Significant domain shift requires updating all weights
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Basic installation
|
||||
pip install peft
|
||||
|
||||
# With quantization support (recommended)
|
||||
pip install peft bitsandbytes
|
||||
|
||||
# Full stack
|
||||
pip install peft transformers accelerate bitsandbytes datasets
|
||||
```
|
||||
|
||||
### LoRA fine-tuning (standard)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
|
||||
from peft import get_peft_model, LoraConfig, TaskType
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load base model
|
||||
model_name = "meta-llama/Llama-3.1-8B"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# LoRA configuration
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
r=16, # Rank (8-64, higher = more capacity)
|
||||
lora_alpha=32, # Scaling factor (typically 2*r)
|
||||
lora_dropout=0.05, # Dropout for regularization
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Attention layers
|
||||
bias="none" # Don't train biases
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.print_trainable_parameters()
|
||||
# Output: trainable params: 13,631,488 || all params: 8,043,307,008 || trainable%: 0.17%
|
||||
|
||||
# Prepare dataset
|
||||
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
|
||||
|
||||
def tokenize(example):
|
||||
text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['response']}"
|
||||
return tokenizer(text, truncation=True, max_length=512, padding="max_length")
|
||||
|
||||
tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)
|
||||
|
||||
# Training
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./lora-llama",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-4,
|
||||
fp16=True,
|
||||
logging_steps=10,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized,
|
||||
data_collator=lambda data: {"input_ids": torch.stack([f["input_ids"] for f in data]),
|
||||
"attention_mask": torch.stack([f["attention_mask"] for f in data]),
|
||||
"labels": torch.stack([f["input_ids"] for f in data])}
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save adapter only (6MB vs 16GB)
|
||||
model.save_pretrained("./lora-llama-adapter")
|
||||
```
|
||||
|
||||
### QLoRA fine-tuning (memory-efficient)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
|
||||
|
||||
# 4-bit quantization config
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4", # NormalFloat4 (best for LLMs)
|
||||
bnb_4bit_compute_dtype="bfloat16", # Compute in bf16
|
||||
bnb_4bit_use_double_quant=True # Nested quantization
|
||||
)
|
||||
|
||||
# Load quantized model
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-70B",
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
# Prepare for training (enables gradient checkpointing)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
# LoRA config for QLoRA
|
||||
lora_config = LoraConfig(
|
||||
r=64, # Higher rank for 70B
|
||||
lora_alpha=128,
|
||||
lora_dropout=0.1,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
# 70B model now fits on single 24GB GPU!
|
||||
```
|
||||
|
||||
## LoRA parameter selection
|
||||
|
||||
### Rank (r) - capacity vs efficiency
|
||||
|
||||
| Rank | Trainable Params | Memory | Quality | Use Case |
|
||||
|------|-----------------|--------|---------|----------|
|
||||
| 4 | ~3M | Minimal | Lower | Simple tasks, prototyping |
|
||||
| **8** | ~7M | Low | Good | **Recommended starting point** |
|
||||
| **16** | ~14M | Medium | Better | **General fine-tuning** |
|
||||
| 32 | ~27M | Higher | High | Complex tasks |
|
||||
| 64 | ~54M | High | Highest | Domain adaptation, 70B models |
|
||||
|
||||
### Alpha (lora_alpha) - scaling factor
|
||||
|
||||
```python
|
||||
# Rule of thumb: alpha = 2 * rank
|
||||
LoraConfig(r=16, lora_alpha=32) # Standard
|
||||
LoraConfig(r=16, lora_alpha=16) # Conservative (lower learning rate effect)
|
||||
LoraConfig(r=16, lora_alpha=64) # Aggressive (higher learning rate effect)
|
||||
```
|
||||
|
||||
### Target modules by architecture
|
||||
|
||||
```python
|
||||
# Llama / Mistral / Qwen
|
||||
target_modules = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
|
||||
|
||||
# GPT-2 / GPT-Neo
|
||||
target_modules = ["c_attn", "c_proj", "c_fc"]
|
||||
|
||||
# Falcon
|
||||
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
|
||||
|
||||
# BLOOM
|
||||
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
|
||||
|
||||
# Auto-detect all linear layers
|
||||
target_modules = "all-linear" # PEFT 0.6.0+
|
||||
```
|
||||
|
||||
## Loading and merging adapters
|
||||
|
||||
### Load trained adapter
|
||||
|
||||
```python
|
||||
from peft import PeftModel, AutoPeftModelForCausalLM
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Option 1: Load with PeftModel
|
||||
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
|
||||
model = PeftModel.from_pretrained(base_model, "./lora-llama-adapter")
|
||||
|
||||
# Option 2: Load directly (recommended)
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
"./lora-llama-adapter",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Merge adapter into base model
|
||||
|
||||
```python
|
||||
# Merge for deployment (no adapter overhead)
|
||||
merged_model = model.merge_and_unload()
|
||||
|
||||
# Save merged model
|
||||
merged_model.save_pretrained("./llama-merged")
|
||||
tokenizer.save_pretrained("./llama-merged")
|
||||
|
||||
# Push to Hub
|
||||
merged_model.push_to_hub("username/llama-finetuned")
|
||||
```
|
||||
|
||||
### Multi-adapter serving
|
||||
|
||||
```python
|
||||
from peft import PeftModel
|
||||
|
||||
# Load base with first adapter
|
||||
model = AutoPeftModelForCausalLM.from_pretrained("./adapter-task1")
|
||||
|
||||
# Load additional adapters
|
||||
model.load_adapter("./adapter-task2", adapter_name="task2")
|
||||
model.load_adapter("./adapter-task3", adapter_name="task3")
|
||||
|
||||
# Switch between adapters at runtime
|
||||
model.set_adapter("task1") # Use task1 adapter
|
||||
output1 = model.generate(**inputs)
|
||||
|
||||
model.set_adapter("task2") # Switch to task2
|
||||
output2 = model.generate(**inputs)
|
||||
|
||||
# Disable adapters (use base model)
|
||||
with model.disable_adapter():
|
||||
base_output = model.generate(**inputs)
|
||||
```
|
||||
|
||||
## PEFT methods comparison
|
||||
|
||||
| Method | Trainable % | Memory | Speed | Best For |
|
||||
|--------|------------|--------|-------|----------|
|
||||
| **LoRA** | 0.1-1% | Low | Fast | General fine-tuning |
|
||||
| **QLoRA** | 0.1-1% | Very Low | Medium | Memory-constrained |
|
||||
| AdaLoRA | 0.1-1% | Low | Medium | Automatic rank selection |
|
||||
| IA3 | 0.01% | Minimal | Fastest | Few-shot adaptation |
|
||||
| Prefix Tuning | 0.1% | Low | Medium | Generation control |
|
||||
| Prompt Tuning | 0.001% | Minimal | Fast | Simple task adaptation |
|
||||
| P-Tuning v2 | 0.1% | Low | Medium | NLU tasks |
|
||||
|
||||
### IA3 (minimal parameters)
|
||||
|
||||
```python
|
||||
from peft import IA3Config
|
||||
|
||||
ia3_config = IA3Config(
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "down_proj"],
|
||||
feedforward_modules=["down_proj"]
|
||||
)
|
||||
model = get_peft_model(model, ia3_config)
|
||||
# Trains only 0.01% of parameters!
|
||||
```
|
||||
|
||||
### Prefix Tuning
|
||||
|
||||
```python
|
||||
from peft import PrefixTuningConfig
|
||||
|
||||
prefix_config = PrefixTuningConfig(
|
||||
task_type="CAUSAL_LM",
|
||||
num_virtual_tokens=20, # Prepended tokens
|
||||
prefix_projection=True # Use MLP projection
|
||||
)
|
||||
model = get_peft_model(model, prefix_config)
|
||||
```
|
||||
|
||||
## Integration patterns
|
||||
|
||||
### With TRL (SFTTrainer)
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules="all-linear")
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=SFTConfig(output_dir="./output", max_seq_length=512),
|
||||
train_dataset=dataset,
|
||||
peft_config=lora_config, # Pass LoRA config directly
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### With Axolotl (YAML config)
|
||||
|
||||
```yaml
|
||||
# axolotl config.yaml
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
lora_target_linear: true # Target all linear layers
|
||||
```
|
||||
|
||||
### With vLLM (inference)
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
# Load base model with LoRA support
|
||||
llm = LLM(model="meta-llama/Llama-3.1-8B", enable_lora=True)
|
||||
|
||||
# Serve with adapter
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
lora_request=LoRARequest("adapter1", 1, "./lora-adapter")
|
||||
)
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
### Memory usage (Llama 3.1 8B)
|
||||
|
||||
| Method | GPU Memory | Trainable Params |
|
||||
|--------|-----------|------------------|
|
||||
| Full fine-tuning | 60+ GB | 8B (100%) |
|
||||
| LoRA r=16 | 18 GB | 14M (0.17%) |
|
||||
| QLoRA r=16 | 6 GB | 14M (0.17%) |
|
||||
| IA3 | 16 GB | 800K (0.01%) |
|
||||
|
||||
### Training speed (A100 80GB)
|
||||
|
||||
| Method | Tokens/sec | vs Full FT |
|
||||
|--------|-----------|------------|
|
||||
| Full FT | 2,500 | 1x |
|
||||
| LoRA | 3,200 | 1.3x |
|
||||
| QLoRA | 2,100 | 0.84x |
|
||||
|
||||
### Quality (MMLU benchmark)
|
||||
|
||||
| Model | Full FT | LoRA | QLoRA |
|
||||
|-------|---------|------|-------|
|
||||
| Llama 2-7B | 45.3 | 44.8 | 44.1 |
|
||||
| Llama 2-13B | 54.8 | 54.2 | 53.5 |
|
||||
|
||||
## Common issues
|
||||
|
||||
### CUDA OOM during training
|
||||
|
||||
```python
|
||||
# Solution 1: Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Solution 2: Reduce batch size + increase accumulation
|
||||
TrainingArguments(
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=16
|
||||
)
|
||||
|
||||
# Solution 3: Use QLoRA
|
||||
from transformers import BitsAndBytesConfig
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
|
||||
```
|
||||
|
||||
### Adapter not applying
|
||||
|
||||
```python
|
||||
# Verify adapter is active
|
||||
print(model.active_adapters) # Should show adapter name
|
||||
|
||||
# Check trainable parameters
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# Ensure model in training mode
|
||||
model.train()
|
||||
```
|
||||
|
||||
### Quality degradation
|
||||
|
||||
```python
|
||||
# Increase rank
|
||||
LoraConfig(r=32, lora_alpha=64)
|
||||
|
||||
# Target more modules
|
||||
target_modules = "all-linear"
|
||||
|
||||
# Use more training data and epochs
|
||||
TrainingArguments(num_train_epochs=5)
|
||||
|
||||
# Lower learning rate
|
||||
TrainingArguments(learning_rate=1e-4)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with r=8-16**, increase if quality insufficient
|
||||
2. **Use alpha = 2 * rank** as starting point
|
||||
3. **Target attention + MLP layers** for best quality/efficiency
|
||||
4. **Enable gradient checkpointing** for memory savings
|
||||
5. **Save adapters frequently** (small files, easy rollback)
|
||||
6. **Evaluate on held-out data** before merging
|
||||
7. **Use QLoRA for 70B+ models** on consumer hardware
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - DoRA, LoftQ, rank stabilization, custom modules
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common errors, debugging, optimization
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/huggingface/peft
|
||||
- **Docs**: https://huggingface.co/docs/peft
|
||||
- **LoRA Paper**: arXiv:2106.09685
|
||||
- **QLoRA Paper**: arXiv:2305.14314
|
||||
- **Models**: https://huggingface.co/models?library=peft
|
||||
514
skills/mlops/peft/references/advanced-usage.md
Normal file
514
skills/mlops/peft/references/advanced-usage.md
Normal file
|
|
@ -0,0 +1,514 @@
|
|||
# PEFT Advanced Usage Guide
|
||||
|
||||
## Advanced LoRA Variants
|
||||
|
||||
### DoRA (Weight-Decomposed Low-Rank Adaptation)
|
||||
|
||||
DoRA decomposes weights into magnitude and direction components, often achieving better results than standard LoRA:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
dora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
use_dora=True, # Enable DoRA
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
model = get_peft_model(model, dora_config)
|
||||
```
|
||||
|
||||
**When to use DoRA**:
|
||||
- Consistently outperforms LoRA on instruction-following tasks
|
||||
- Slightly higher memory (~10%) due to magnitude vectors
|
||||
- Best for quality-critical fine-tuning
|
||||
|
||||
### AdaLoRA (Adaptive Rank)
|
||||
|
||||
Automatically adjusts rank per layer based on importance:
|
||||
|
||||
```python
|
||||
from peft import AdaLoraConfig
|
||||
|
||||
adalora_config = AdaLoraConfig(
|
||||
init_r=64, # Initial rank
|
||||
target_r=16, # Target average rank
|
||||
tinit=200, # Warmup steps
|
||||
tfinal=1000, # Final pruning step
|
||||
deltaT=10, # Rank update frequency
|
||||
beta1=0.85,
|
||||
beta2=0.85,
|
||||
orth_reg_weight=0.5, # Orthogonality regularization
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Allocates more rank to important layers
|
||||
- Can reduce total parameters while maintaining quality
|
||||
- Good for exploring optimal rank distribution
|
||||
|
||||
### LoRA+ (Asymmetric Learning Rates)
|
||||
|
||||
Different learning rates for A and B matrices:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
# LoRA+ uses higher LR for B matrix
|
||||
lora_plus_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
use_rslora=True, # Rank-stabilized LoRA (related technique)
|
||||
)
|
||||
|
||||
# Manual implementation of LoRA+
|
||||
from torch.optim import AdamW
|
||||
|
||||
# Group parameters
|
||||
lora_A_params = [p for n, p in model.named_parameters() if "lora_A" in n]
|
||||
lora_B_params = [p for n, p in model.named_parameters() if "lora_B" in n]
|
||||
|
||||
optimizer = AdamW([
|
||||
{"params": lora_A_params, "lr": 1e-4},
|
||||
{"params": lora_B_params, "lr": 1e-3}, # 10x higher for B
|
||||
])
|
||||
```
|
||||
|
||||
### rsLoRA (Rank-Stabilized LoRA)
|
||||
|
||||
Scales LoRA outputs to stabilize training with different ranks:
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=64,
|
||||
lora_alpha=64,
|
||||
use_rslora=True, # Enables rank-stabilized scaling
|
||||
target_modules="all-linear"
|
||||
)
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- When experimenting with different ranks
|
||||
- Helps maintain consistent behavior across rank values
|
||||
- Recommended for r > 32
|
||||
|
||||
## LoftQ (LoRA-Fine-Tuning-aware Quantization)
|
||||
|
||||
Initializes LoRA weights to compensate for quantization error:
|
||||
|
||||
```python
|
||||
from peft import LoftQConfig, LoraConfig, get_peft_model
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
# LoftQ configuration
|
||||
loftq_config = LoftQConfig(
|
||||
loftq_bits=4, # Quantization bits
|
||||
loftq_iter=5, # Alternating optimization iterations
|
||||
)
|
||||
|
||||
# LoRA config with LoftQ initialization
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
init_lora_weights="loftq",
|
||||
loftq_config=loftq_config,
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
# Load quantized model
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
**Benefits over standard QLoRA**:
|
||||
- Better initial quality after quantization
|
||||
- Faster convergence
|
||||
- ~1-2% better final accuracy on benchmarks
|
||||
|
||||
## Custom Module Targeting
|
||||
|
||||
### Target specific layers
|
||||
|
||||
```python
|
||||
# Target only first and last transformer layers
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["model.layers.0.self_attn.q_proj",
|
||||
"model.layers.0.self_attn.v_proj",
|
||||
"model.layers.31.self_attn.q_proj",
|
||||
"model.layers.31.self_attn.v_proj"],
|
||||
layers_to_transform=[0, 31] # Alternative approach
|
||||
)
|
||||
```
|
||||
|
||||
### Layer pattern matching
|
||||
|
||||
```python
|
||||
# Target layers 0-10 only
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
layers_to_transform=list(range(11)), # Layers 0-10
|
||||
layers_pattern="model.layers"
|
||||
)
|
||||
```
|
||||
|
||||
### Exclude specific layers
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
target_modules="all-linear",
|
||||
modules_to_save=["lm_head"], # Train these fully (not LoRA)
|
||||
)
|
||||
```
|
||||
|
||||
## Embedding and LM Head Training
|
||||
|
||||
### Train embeddings with LoRA
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
# Include embeddings
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "v_proj", "embed_tokens"], # Include embeddings
|
||||
modules_to_save=["lm_head"], # Train lm_head fully
|
||||
)
|
||||
```
|
||||
|
||||
### Extending vocabulary with LoRA
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import get_peft_model, LoraConfig
|
||||
|
||||
# Add new tokens
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
|
||||
new_tokens = ["<custom_token_1>", "<custom_token_2>"]
|
||||
tokenizer.add_tokens(new_tokens)
|
||||
|
||||
# Resize model embeddings
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Configure LoRA to train new embeddings
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
target_modules="all-linear",
|
||||
modules_to_save=["embed_tokens", "lm_head"], # Train these fully
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
## Multi-Adapter Patterns
|
||||
|
||||
### Adapter composition
|
||||
|
||||
```python
|
||||
from peft import PeftModel
|
||||
|
||||
# Load model with multiple adapters
|
||||
model = AutoPeftModelForCausalLM.from_pretrained("./base-adapter")
|
||||
model.load_adapter("./style-adapter", adapter_name="style")
|
||||
model.load_adapter("./task-adapter", adapter_name="task")
|
||||
|
||||
# Combine adapters (weighted sum)
|
||||
model.add_weighted_adapter(
|
||||
adapters=["style", "task"],
|
||||
weights=[0.7, 0.3],
|
||||
adapter_name="combined",
|
||||
combination_type="linear" # or "cat", "svd"
|
||||
)
|
||||
|
||||
model.set_adapter("combined")
|
||||
```
|
||||
|
||||
### Adapter stacking
|
||||
|
||||
```python
|
||||
# Stack adapters (apply sequentially)
|
||||
model.add_weighted_adapter(
|
||||
adapters=["base", "domain", "task"],
|
||||
weights=[1.0, 1.0, 1.0],
|
||||
adapter_name="stacked",
|
||||
combination_type="cat" # Concatenate adapter outputs
|
||||
)
|
||||
```
|
||||
|
||||
### Dynamic adapter switching
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class MultiAdapterModel:
|
||||
def __init__(self, base_model_path, adapter_paths):
|
||||
self.model = AutoPeftModelForCausalLM.from_pretrained(adapter_paths[0])
|
||||
for name, path in adapter_paths[1:].items():
|
||||
self.model.load_adapter(path, adapter_name=name)
|
||||
|
||||
def generate(self, prompt, adapter_name="default"):
|
||||
self.model.set_adapter(adapter_name)
|
||||
return self.model.generate(**self.tokenize(prompt))
|
||||
|
||||
def generate_ensemble(self, prompt, adapters, weights):
|
||||
"""Generate with weighted adapter ensemble"""
|
||||
outputs = []
|
||||
for adapter, weight in zip(adapters, weights):
|
||||
self.model.set_adapter(adapter)
|
||||
logits = self.model(**self.tokenize(prompt)).logits
|
||||
outputs.append(weight * logits)
|
||||
return torch.stack(outputs).sum(dim=0)
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Gradient checkpointing with LoRA
|
||||
|
||||
```python
|
||||
from peft import prepare_model_for_kbit_training
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model = prepare_model_for_kbit_training(
|
||||
model,
|
||||
use_gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
```
|
||||
|
||||
### CPU offloading for training
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="bf16",
|
||||
gradient_accumulation_steps=8,
|
||||
cpu_offload=True # Offload optimizer states to CPU
|
||||
)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
### Memory-efficient attention with LoRA
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Combine Flash Attention 2 with LoRA
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(model, lora_config)
|
||||
```
|
||||
|
||||
## Inference Optimization
|
||||
|
||||
### Merge for deployment
|
||||
|
||||
```python
|
||||
# Merge adapter weights into base model
|
||||
merged_model = model.merge_and_unload()
|
||||
|
||||
# Quantize merged model for inference
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"./merged-model",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
```
|
||||
|
||||
### Export to different formats
|
||||
|
||||
```python
|
||||
# Export to GGUF (llama.cpp)
|
||||
# First merge, then convert
|
||||
merged_model.save_pretrained("./merged-model")
|
||||
|
||||
# Use llama.cpp converter
|
||||
# python convert-hf-to-gguf.py ./merged-model --outfile model.gguf
|
||||
|
||||
# Export to ONNX
|
||||
from optimum.onnxruntime import ORTModelForCausalLM
|
||||
|
||||
ort_model = ORTModelForCausalLM.from_pretrained(
|
||||
"./merged-model",
|
||||
export=True
|
||||
)
|
||||
ort_model.save_pretrained("./onnx-model")
|
||||
```
|
||||
|
||||
### Batch adapter inference
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
# Initialize with LoRA support
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B",
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_loras=4 # Max concurrent adapters
|
||||
)
|
||||
|
||||
# Batch with different adapters
|
||||
requests = [
|
||||
("prompt1", LoRARequest("adapter1", 1, "./adapter1")),
|
||||
("prompt2", LoRARequest("adapter2", 2, "./adapter2")),
|
||||
("prompt3", LoRARequest("adapter1", 1, "./adapter1")),
|
||||
]
|
||||
|
||||
outputs = llm.generate(
|
||||
[r[0] for r in requests],
|
||||
lora_request=[r[1] for r in requests]
|
||||
)
|
||||
```
|
||||
|
||||
## Training Recipes
|
||||
|
||||
### Instruction tuning recipe
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
target_modules="all-linear",
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./output",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=2e-4,
|
||||
lr_scheduler_type="cosine",
|
||||
warmup_ratio=0.03,
|
||||
bf16=True,
|
||||
logging_steps=10,
|
||||
save_strategy="steps",
|
||||
save_steps=100,
|
||||
eval_strategy="steps",
|
||||
eval_steps=100,
|
||||
)
|
||||
```
|
||||
|
||||
### Code generation recipe
|
||||
|
||||
```python
|
||||
lora_config = LoraConfig(
|
||||
r=32, # Higher rank for code
|
||||
lora_alpha=64,
|
||||
lora_dropout=0.1,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
learning_rate=1e-4, # Lower LR for code
|
||||
num_train_epochs=2,
|
||||
max_seq_length=2048, # Longer sequences
|
||||
)
|
||||
```
|
||||
|
||||
### Conversational/Chat recipe
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=16, # alpha = r for chat
|
||||
lora_dropout=0.05,
|
||||
target_modules="all-linear"
|
||||
)
|
||||
|
||||
# Use chat template
|
||||
def format_chat(example):
|
||||
messages = [
|
||||
{"role": "user", "content": example["instruction"]},
|
||||
{"role": "assistant", "content": example["response"]}
|
||||
]
|
||||
return tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
peft_config=lora_config,
|
||||
train_dataset=dataset.map(format_chat),
|
||||
max_seq_length=1024,
|
||||
)
|
||||
```
|
||||
|
||||
## Debugging and Validation
|
||||
|
||||
### Verify adapter application
|
||||
|
||||
```python
|
||||
# Check which modules have LoRA
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "lora_A"):
|
||||
print(f"LoRA applied to: {name}")
|
||||
|
||||
# Print detailed config
|
||||
print(model.peft_config)
|
||||
|
||||
# Check adapter state
|
||||
print(f"Active adapters: {model.active_adapters}")
|
||||
print(f"Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
```
|
||||
|
||||
### Compare with base model
|
||||
|
||||
```python
|
||||
# Generate with adapter
|
||||
model.set_adapter("default")
|
||||
adapter_output = model.generate(**inputs)
|
||||
|
||||
# Generate without adapter
|
||||
with model.disable_adapter():
|
||||
base_output = model.generate(**inputs)
|
||||
|
||||
print(f"Adapter: {tokenizer.decode(adapter_output[0])}")
|
||||
print(f"Base: {tokenizer.decode(base_output[0])}")
|
||||
```
|
||||
|
||||
### Monitor training metrics
|
||||
|
||||
```python
|
||||
from transformers import TrainerCallback
|
||||
|
||||
class LoRACallback(TrainerCallback):
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if "loss" in logs:
|
||||
# Log adapter-specific metrics
|
||||
model = kwargs["model"]
|
||||
lora_params = sum(p.numel() for n, p in model.named_parameters()
|
||||
if "lora" in n and p.requires_grad)
|
||||
print(f"Step {state.global_step}: loss={logs['loss']:.4f}, lora_params={lora_params}")
|
||||
```
|
||||
480
skills/mlops/peft/references/troubleshooting.md
Normal file
480
skills/mlops/peft/references/troubleshooting.md
Normal file
|
|
@ -0,0 +1,480 @@
|
|||
# PEFT Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### bitsandbytes CUDA Error
|
||||
|
||||
**Error**: `CUDA Setup failed despite GPU being available`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
|
||||
# Install matching bitsandbytes
|
||||
pip uninstall bitsandbytes
|
||||
pip install bitsandbytes --no-cache-dir
|
||||
|
||||
# Or compile from source for specific CUDA
|
||||
git clone https://github.com/TimDettmers/bitsandbytes.git
|
||||
cd bitsandbytes
|
||||
CUDA_VERSION=118 make cuda11x # Adjust for your CUDA
|
||||
pip install .
|
||||
```
|
||||
|
||||
### Triton Import Error
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'triton'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install triton (Linux only)
|
||||
pip install triton
|
||||
|
||||
# Windows: Triton not supported, use CUDA backend
|
||||
# Set environment variable to disable triton
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
```
|
||||
|
||||
### PEFT Version Conflicts
|
||||
|
||||
**Error**: `AttributeError: 'LoraConfig' object has no attribute 'use_dora'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Upgrade to latest PEFT
|
||||
pip install peft>=0.13.0 --upgrade
|
||||
|
||||
# Check version
|
||||
python -c "import peft; print(peft.__version__)"
|
||||
```
|
||||
|
||||
## Training Issues
|
||||
|
||||
### CUDA Out of Memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable gradient checkpointing**:
|
||||
```python
|
||||
from peft import prepare_model_for_kbit_training
|
||||
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|
||||
```
|
||||
|
||||
2. **Reduce batch size**:
|
||||
```python
|
||||
TrainingArguments(
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=16 # Maintain effective batch size
|
||||
)
|
||||
```
|
||||
|
||||
3. **Use QLoRA**:
|
||||
```python
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
|
||||
```
|
||||
|
||||
4. **Lower LoRA rank**:
|
||||
```python
|
||||
LoraConfig(r=8) # Instead of r=16 or higher
|
||||
```
|
||||
|
||||
5. **Target fewer modules**:
|
||||
```python
|
||||
target_modules=["q_proj", "v_proj"] # Instead of all-linear
|
||||
```
|
||||
|
||||
### Loss Not Decreasing
|
||||
|
||||
**Problem**: Training loss stays flat or increases.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check learning rate**:
|
||||
```python
|
||||
# Start lower
|
||||
TrainingArguments(learning_rate=1e-4) # Not 2e-4 or higher
|
||||
```
|
||||
|
||||
2. **Verify adapter is active**:
|
||||
```python
|
||||
model.print_trainable_parameters()
|
||||
# Should show >0 trainable params
|
||||
|
||||
# Check adapter applied
|
||||
print(model.peft_config)
|
||||
```
|
||||
|
||||
3. **Check data formatting**:
|
||||
```python
|
||||
# Verify tokenization
|
||||
sample = dataset[0]
|
||||
decoded = tokenizer.decode(sample["input_ids"])
|
||||
print(decoded) # Should look correct
|
||||
```
|
||||
|
||||
4. **Increase rank**:
|
||||
```python
|
||||
LoraConfig(r=32, lora_alpha=64) # More capacity
|
||||
```
|
||||
|
||||
### NaN Loss
|
||||
|
||||
**Error**: `Loss is NaN`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Use bf16 instead of fp16
|
||||
TrainingArguments(bf16=True, fp16=False)
|
||||
|
||||
# Or enable loss scaling
|
||||
TrainingArguments(fp16=True, fp16_full_eval=True)
|
||||
|
||||
# Lower learning rate
|
||||
TrainingArguments(learning_rate=5e-5)
|
||||
|
||||
# Check for data issues
|
||||
for batch in dataloader:
|
||||
if torch.isnan(batch["input_ids"].float()).any():
|
||||
print("NaN in input!")
|
||||
```
|
||||
|
||||
### Adapter Not Training
|
||||
|
||||
**Problem**: `trainable params: 0` or model not updating.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Verify LoRA applied to correct modules
|
||||
for name, module in model.named_modules():
|
||||
if "lora" in name.lower():
|
||||
print(f"Found LoRA: {name}")
|
||||
|
||||
# Check target_modules match model architecture
|
||||
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
||||
print(TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(model.config.model_type))
|
||||
|
||||
# Ensure model in training mode
|
||||
model.train()
|
||||
|
||||
# Check requires_grad
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
print(f"Trainable: {name}")
|
||||
```
|
||||
|
||||
## Loading Issues
|
||||
|
||||
### Adapter Loading Fails
|
||||
|
||||
**Error**: `ValueError: Can't find adapter weights`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check adapter files exist
|
||||
import os
|
||||
print(os.listdir("./adapter-path"))
|
||||
# Should contain: adapter_config.json, adapter_model.safetensors
|
||||
|
||||
# Load with correct structure
|
||||
from peft import PeftModel, PeftConfig
|
||||
|
||||
# Check config
|
||||
config = PeftConfig.from_pretrained("./adapter-path")
|
||||
print(config)
|
||||
|
||||
# Load base model first
|
||||
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
|
||||
model = PeftModel.from_pretrained(base_model, "./adapter-path")
|
||||
```
|
||||
|
||||
### Base Model Mismatch
|
||||
|
||||
**Error**: `RuntimeError: size mismatch`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure base model matches adapter
|
||||
from peft import PeftConfig
|
||||
|
||||
config = PeftConfig.from_pretrained("./adapter-path")
|
||||
print(f"Base model: {config.base_model_name_or_path}")
|
||||
|
||||
# Load exact same base model
|
||||
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
|
||||
```
|
||||
|
||||
### Safetensors vs PyTorch Format
|
||||
|
||||
**Error**: `ValueError: We couldn't connect to 'https://huggingface.co'`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Force local loading
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
"./adapter-path",
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# Or specify format
|
||||
model.save_pretrained("./adapter", safe_serialization=True) # safetensors
|
||||
model.save_pretrained("./adapter", safe_serialization=False) # pytorch
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Slow Generation
|
||||
|
||||
**Problem**: Inference much slower than expected.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Merge adapter for deployment**:
|
||||
```python
|
||||
merged_model = model.merge_and_unload()
|
||||
# No adapter overhead during inference
|
||||
```
|
||||
|
||||
2. **Use optimized inference engine**:
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM(model="./merged-model", dtype="half")
|
||||
```
|
||||
|
||||
3. **Enable Flash Attention**:
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
attn_implementation="flash_attention_2"
|
||||
)
|
||||
```
|
||||
|
||||
### Output Quality Issues
|
||||
|
||||
**Problem**: Fine-tuned model produces worse outputs.
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check evaluation without adapter**:
|
||||
```python
|
||||
with model.disable_adapter():
|
||||
base_output = model.generate(**inputs)
|
||||
# Compare with adapter output
|
||||
```
|
||||
|
||||
2. **Lower temperature during eval**:
|
||||
```python
|
||||
model.generate(**inputs, temperature=0.1, do_sample=False)
|
||||
```
|
||||
|
||||
3. **Retrain with more data**:
|
||||
```python
|
||||
# Increase training samples
|
||||
# Use higher quality data
|
||||
# Train for more epochs
|
||||
```
|
||||
|
||||
### Wrong Adapter Active
|
||||
|
||||
**Problem**: Model using wrong adapter or no adapter.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check active adapters
|
||||
print(model.active_adapters)
|
||||
|
||||
# Explicitly set adapter
|
||||
model.set_adapter("your-adapter-name")
|
||||
|
||||
# List all adapters
|
||||
print(model.peft_config.keys())
|
||||
```
|
||||
|
||||
## QLoRA Specific Issues
|
||||
|
||||
### Quantization Errors
|
||||
|
||||
**Error**: `RuntimeError: mat1 and mat2 shapes cannot be multiplied`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure compute dtype matches
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16, # Match model dtype
|
||||
bnb_4bit_quant_type="nf4"
|
||||
)
|
||||
|
||||
# Load with correct dtype
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
quantization_config=bnb_config,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
### QLoRA OOM
|
||||
|
||||
**Error**: OOM even with 4-bit quantization.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Enable double quantization
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True # Further memory reduction
|
||||
)
|
||||
|
||||
# Use offloading
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto",
|
||||
max_memory={0: "20GB", "cpu": "100GB"}
|
||||
)
|
||||
```
|
||||
|
||||
### QLoRA Merge Fails
|
||||
|
||||
**Error**: `RuntimeError: expected scalar type BFloat16 but found Float`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Dequantize before merging
|
||||
from peft import PeftModel
|
||||
|
||||
# Load in higher precision for merging
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model_name,
|
||||
torch_dtype=torch.float16, # Not quantized
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
# Load adapter
|
||||
model = PeftModel.from_pretrained(base_model, "./qlora-adapter")
|
||||
|
||||
# Now merge
|
||||
merged = model.merge_and_unload()
|
||||
```
|
||||
|
||||
## Multi-Adapter Issues
|
||||
|
||||
### Adapter Conflict
|
||||
|
||||
**Error**: `ValueError: Adapter with name 'default' already exists`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Use unique names
|
||||
model.load_adapter("./adapter1", adapter_name="task1")
|
||||
model.load_adapter("./adapter2", adapter_name="task2")
|
||||
|
||||
# Or delete existing
|
||||
model.delete_adapter("default")
|
||||
```
|
||||
|
||||
### Mixed Precision Adapters
|
||||
|
||||
**Error**: Adapters trained with different dtypes.
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Convert adapter precision
|
||||
model = PeftModel.from_pretrained(base_model, "./adapter")
|
||||
model = model.to(torch.bfloat16)
|
||||
|
||||
# Or load with specific dtype
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
"./adapter",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Memory Profiling
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
def print_memory():
|
||||
if torch.cuda.is_available():
|
||||
allocated = torch.cuda.memory_allocated() / 1e9
|
||||
reserved = torch.cuda.memory_reserved() / 1e9
|
||||
print(f"Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
|
||||
|
||||
# Profile during training
|
||||
print_memory() # Before
|
||||
model.train()
|
||||
loss = model(**batch).loss
|
||||
loss.backward()
|
||||
print_memory() # After
|
||||
```
|
||||
|
||||
### Speed Profiling
|
||||
|
||||
```python
|
||||
import time
|
||||
import torch
|
||||
|
||||
def benchmark_generation(model, tokenizer, prompt, n_runs=5):
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# Warmup
|
||||
model.generate(**inputs, max_new_tokens=10)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
start = time.perf_counter()
|
||||
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||
torch.cuda.synchronize()
|
||||
times.append(time.perf_counter() - start)
|
||||
|
||||
tokens = outputs.shape[1] - inputs.input_ids.shape[1]
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"Speed: {tokens/avg_time:.2f} tokens/sec")
|
||||
|
||||
# Compare adapter vs merged
|
||||
benchmark_generation(adapter_model, tokenizer, "Hello")
|
||||
benchmark_generation(merged_model, tokenizer, "Hello")
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **Check PEFT GitHub Issues**: https://github.com/huggingface/peft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co/
|
||||
3. **PEFT Documentation**: https://huggingface.co/docs/peft
|
||||
|
||||
### Debugging Template
|
||||
|
||||
When reporting issues, include:
|
||||
|
||||
```python
|
||||
# System info
|
||||
import peft
|
||||
import transformers
|
||||
import torch
|
||||
|
||||
print(f"PEFT: {peft.__version__}")
|
||||
print(f"Transformers: {transformers.__version__}")
|
||||
print(f"PyTorch: {torch.__version__}")
|
||||
print(f"CUDA: {torch.version.cuda}")
|
||||
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
|
||||
|
||||
# Config
|
||||
print(model.peft_config)
|
||||
model.print_trainable_parameters()
|
||||
```
|
||||
129
skills/mlops/pytorch-fsdp/SKILL.md
Normal file
129
skills/mlops/pytorch-fsdp/SKILL.md
Normal file
File diff suppressed because one or more lines are too long
7
skills/mlops/pytorch-fsdp/references/index.md
Normal file
7
skills/mlops/pytorch-fsdp/references/index.md
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Pytorch-Fsdp Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Other
|
||||
**File:** `other.md`
|
||||
**Pages:** 15
|
||||
4249
skills/mlops/pytorch-fsdp/references/other.md
Normal file
4249
skills/mlops/pytorch-fsdp/references/other.md
Normal file
File diff suppressed because it is too large
Load diff
349
skills/mlops/pytorch-lightning/SKILL.md
Normal file
349
skills/mlops/pytorch-lightning/SKILL.md
Normal file
|
|
@ -0,0 +1,349 @@
|
|||
---
|
||||
name: pytorch-lightning
|
||||
description: High-level PyTorch framework with Trainer class, automatic distributed training (DDP/FSDP/DeepSpeed), callbacks system, and minimal boilerplate. Scales from laptop to supercomputer with same code. Use when you want clean training loops with built-in best practices.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [lightning, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [PyTorch Lightning, Training Framework, Distributed Training, DDP, FSDP, DeepSpeed, High-Level API, Callbacks, Best Practices, Scalable]
|
||||
|
||||
---
|
||||
|
||||
# PyTorch Lightning - High-Level Training Framework
|
||||
|
||||
## Quick start
|
||||
|
||||
PyTorch Lightning organizes PyTorch code to eliminate boilerplate while maintaining flexibility.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install lightning
|
||||
```
|
||||
|
||||
**Convert PyTorch to Lightning** (3 steps):
|
||||
|
||||
```python
|
||||
import lightning as L
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
# Step 1: Define LightningModule (organize your PyTorch code)
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self, hidden_size=128):
|
||||
super().__init__()
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(28 * 28, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size, 10)
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = nn.functional.cross_entropy(y_hat, y)
|
||||
self.log('train_loss', loss) # Auto-logged to TensorBoard
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
# Step 2: Create data
|
||||
train_loader = DataLoader(train_dataset, batch_size=32)
|
||||
|
||||
# Step 3: Train with Trainer (handles everything else!)
|
||||
trainer = L.Trainer(max_epochs=10, accelerator='gpu', devices=2)
|
||||
model = LitModel()
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**That's it!** Trainer handles:
|
||||
- GPU/TPU/CPU switching
|
||||
- Distributed training (DDP, FSDP, DeepSpeed)
|
||||
- Mixed precision (FP16, BF16)
|
||||
- Gradient accumulation
|
||||
- Checkpointing
|
||||
- Logging
|
||||
- Progress bars
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: From PyTorch to Lightning
|
||||
|
||||
**Original PyTorch code**:
|
||||
```python
|
||||
model = MyModel()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
model.to('cuda')
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
for batch in train_loader:
|
||||
batch = batch.to('cuda')
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Lightning version**:
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MyModel()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.model(batch) # No .to('cuda') needed!
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters())
|
||||
|
||||
# Train
|
||||
trainer = L.Trainer(max_epochs=10, accelerator='gpu')
|
||||
trainer.fit(LitModel(), train_loader)
|
||||
```
|
||||
|
||||
**Benefits**: 40+ lines → 15 lines, no device management, automatic distributed
|
||||
|
||||
### Workflow 2: Validation and testing
|
||||
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MyModel()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = nn.functional.cross_entropy(y_hat, y)
|
||||
self.log('train_loss', loss)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
val_loss = nn.functional.cross_entropy(y_hat, y)
|
||||
acc = (y_hat.argmax(dim=1) == y).float().mean()
|
||||
self.log('val_loss', val_loss)
|
||||
self.log('val_acc', acc)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
test_loss = nn.functional.cross_entropy(y_hat, y)
|
||||
self.log('test_loss', test_loss)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
# Train with validation
|
||||
trainer = L.Trainer(max_epochs=10)
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
# Test
|
||||
trainer.test(model, test_loader)
|
||||
```
|
||||
|
||||
**Automatic features**:
|
||||
- Validation runs every epoch by default
|
||||
- Metrics logged to TensorBoard
|
||||
- Best model checkpointing based on val_loss
|
||||
|
||||
### Workflow 3: Distributed training (DDP)
|
||||
|
||||
```python
|
||||
# Same code as single GPU!
|
||||
model = LitModel()
|
||||
|
||||
# 8 GPUs with DDP (automatic!)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8,
|
||||
strategy='ddp' # Or 'fsdp', 'deepspeed'
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
# Single command, Lightning handles the rest
|
||||
python train.py
|
||||
```
|
||||
|
||||
**No changes needed**:
|
||||
- Automatic data distribution
|
||||
- Gradient synchronization
|
||||
- Multi-node support (just set `num_nodes=2`)
|
||||
|
||||
### Workflow 4: Callbacks for monitoring
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
|
||||
|
||||
# Create callbacks
|
||||
checkpoint = ModelCheckpoint(
|
||||
monitor='val_loss',
|
||||
mode='min',
|
||||
save_top_k=3,
|
||||
filename='model-{epoch:02d}-{val_loss:.2f}'
|
||||
)
|
||||
|
||||
early_stop = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=5,
|
||||
mode='min'
|
||||
)
|
||||
|
||||
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||
|
||||
# Add to Trainer
|
||||
trainer = L.Trainer(
|
||||
max_epochs=100,
|
||||
callbacks=[checkpoint, early_stop, lr_monitor]
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
```
|
||||
|
||||
**Result**:
|
||||
- Auto-saves best 3 models
|
||||
- Stops early if no improvement for 5 epochs
|
||||
- Logs learning rate to TensorBoard
|
||||
|
||||
### Workflow 5: Learning rate scheduling
|
||||
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
# ... (training_step, etc.)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||
|
||||
# Cosine annealing
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer,
|
||||
T_max=100,
|
||||
eta_min=1e-5
|
||||
)
|
||||
|
||||
return {
|
||||
'optimizer': optimizer,
|
||||
'lr_scheduler': {
|
||||
'scheduler': scheduler,
|
||||
'interval': 'epoch', # Update per epoch
|
||||
'frequency': 1
|
||||
}
|
||||
}
|
||||
|
||||
# Learning rate auto-logged!
|
||||
trainer = L.Trainer(max_epochs=100)
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use PyTorch Lightning when**:
|
||||
- Want clean, organized code
|
||||
- Need production-ready training loops
|
||||
- Switching between single GPU, multi-GPU, TPU
|
||||
- Want built-in callbacks and logging
|
||||
- Team collaboration (standardized structure)
|
||||
|
||||
**Key advantages**:
|
||||
- **Organized**: Separates research code from engineering
|
||||
- **Automatic**: DDP, FSDP, DeepSpeed with 1 line
|
||||
- **Callbacks**: Modular training extensions
|
||||
- **Reproducible**: Less boilerplate = fewer bugs
|
||||
- **Tested**: 1M+ downloads/month, battle-tested
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Accelerate**: Minimal changes to existing code, more flexibility
|
||||
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
|
||||
- **Raw PyTorch**: Maximum control, learning purposes
|
||||
- **Keras**: TensorFlow ecosystem
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Loss not decreasing**
|
||||
|
||||
Check data and model setup:
|
||||
```python
|
||||
# Add to training_step
|
||||
def training_step(self, batch, batch_idx):
|
||||
if batch_idx == 0:
|
||||
print(f"Batch shape: {batch[0].shape}")
|
||||
print(f"Labels: {batch[1]}")
|
||||
loss = ...
|
||||
return loss
|
||||
```
|
||||
|
||||
**Issue: Out of memory**
|
||||
|
||||
Reduce batch size or use gradient accumulation:
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
accumulate_grad_batches=4, # Effective batch = batch_size × 4
|
||||
precision='bf16' # Or 'fp16', reduces memory 50%
|
||||
)
|
||||
```
|
||||
|
||||
**Issue: Validation not running**
|
||||
|
||||
Ensure you pass val_loader:
|
||||
```python
|
||||
# WRONG
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# CORRECT
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
```
|
||||
|
||||
**Issue: DDP spawns multiple processes unexpectedly**
|
||||
|
||||
Lightning auto-detects GPUs. Explicitly set devices:
|
||||
```python
|
||||
# Test on CPU first
|
||||
trainer = L.Trainer(accelerator='cpu', devices=1)
|
||||
|
||||
# Then GPU
|
||||
trainer = L.Trainer(accelerator='gpu', devices=1)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Callbacks**: See [references/callbacks.md](references/callbacks.md) for EarlyStopping, ModelCheckpoint, custom callbacks, and callback hooks.
|
||||
|
||||
**Distributed strategies**: See [references/distributed.md](references/distributed.md) for DDP, FSDP, DeepSpeed ZeRO integration, multi-node setup.
|
||||
|
||||
**Hyperparameter tuning**: See [references/hyperparameter-tuning.md](references/hyperparameter-tuning.md) for integration with Optuna, Ray Tune, and WandB sweeps.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **CPU**: Works (good for debugging)
|
||||
- **Single GPU**: Works
|
||||
- **Multi-GPU**: DDP (default), FSDP, or DeepSpeed
|
||||
- **Multi-node**: DDP, FSDP, DeepSpeed
|
||||
- **TPU**: Supported (8 cores)
|
||||
- **Apple MPS**: Supported
|
||||
|
||||
**Precision options**:
|
||||
- FP32 (default)
|
||||
- FP16 (V100, older GPUs)
|
||||
- BF16 (A100/H100, recommended)
|
||||
- FP8 (H100)
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://lightning.ai/docs/pytorch/stable/
|
||||
- GitHub: https://github.com/Lightning-AI/pytorch-lightning ⭐ 29,000+
|
||||
- Version: 2.5.5+
|
||||
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples
|
||||
- Discord: https://discord.gg/lightning-ai
|
||||
- Used by: Kaggle winners, research labs, production teams
|
||||
|
||||
|
||||
436
skills/mlops/pytorch-lightning/references/callbacks.md
Normal file
436
skills/mlops/pytorch-lightning/references/callbacks.md
Normal file
|
|
@ -0,0 +1,436 @@
|
|||
# PyTorch Lightning Callbacks
|
||||
|
||||
## Overview
|
||||
|
||||
Callbacks add functionality to training without modifying the LightningModule. They capture **non-essential logic** like checkpointing, early stopping, and logging.
|
||||
|
||||
## Built-In Callbacks
|
||||
|
||||
### 1. ModelCheckpoint
|
||||
|
||||
**Saves best models during training**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
|
||||
# Save top 3 models based on validation loss
|
||||
checkpoint = ModelCheckpoint(
|
||||
dirpath='checkpoints/',
|
||||
filename='model-{epoch:02d}-{val_loss:.2f}',
|
||||
monitor='val_loss',
|
||||
mode='min',
|
||||
save_top_k=3,
|
||||
save_last=True, # Also save last epoch
|
||||
verbose=True
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[checkpoint])
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
```
|
||||
|
||||
**Configuration options**:
|
||||
```python
|
||||
checkpoint = ModelCheckpoint(
|
||||
monitor='val_acc', # Metric to monitor
|
||||
mode='max', # 'max' for accuracy, 'min' for loss
|
||||
save_top_k=5, # Keep best 5 models
|
||||
save_last=True, # Save last epoch separately
|
||||
every_n_epochs=1, # Save every N epochs
|
||||
save_on_train_epoch_end=False, # Save on validation end instead
|
||||
filename='best-{epoch}-{val_acc:.3f}', # Naming pattern
|
||||
auto_insert_metric_name=False # Don't auto-add metric to filename
|
||||
)
|
||||
```
|
||||
|
||||
**Load checkpoint**:
|
||||
```python
|
||||
# Load best model
|
||||
best_model_path = checkpoint.best_model_path
|
||||
model = LitModel.load_from_checkpoint(best_model_path)
|
||||
|
||||
# Resume training
|
||||
trainer = L.Trainer(callbacks=[checkpoint])
|
||||
trainer.fit(model, train_loader, val_loader, ckpt_path='checkpoints/last.ckpt')
|
||||
```
|
||||
|
||||
### 2. EarlyStopping
|
||||
|
||||
**Stops training when metric stops improving**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import EarlyStopping
|
||||
|
||||
early_stop = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=5, # Wait 5 epochs
|
||||
mode='min',
|
||||
min_delta=0.001, # Minimum change to qualify as improvement
|
||||
verbose=True,
|
||||
strict=True, # Crash if monitored metric not found
|
||||
check_on_train_epoch_end=False # Check on validation end
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[early_stop])
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
# Stops automatically if no improvement for 5 epochs
|
||||
```
|
||||
|
||||
**Advanced usage**:
|
||||
```python
|
||||
early_stop = EarlyStopping(
|
||||
monitor='val_loss',
|
||||
patience=10,
|
||||
min_delta=0.0,
|
||||
verbose=True,
|
||||
mode='min',
|
||||
stopping_threshold=0.1, # Stop if val_loss < 0.1
|
||||
divergence_threshold=5.0, # Stop if val_loss > 5.0
|
||||
check_finite=True # Stop on NaN/Inf
|
||||
)
|
||||
```
|
||||
|
||||
### 3. LearningRateMonitor
|
||||
|
||||
**Logs learning rate**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import LearningRateMonitor
|
||||
|
||||
lr_monitor = LearningRateMonitor(
|
||||
logging_interval='epoch', # Or 'step'
|
||||
log_momentum=True # Also log momentum
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[lr_monitor])
|
||||
# Learning rate automatically logged to TensorBoard/WandB
|
||||
```
|
||||
|
||||
### 4. TQDMProgressBar
|
||||
|
||||
**Customizes progress bar**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||
|
||||
progress_bar = TQDMProgressBar(
|
||||
refresh_rate=10, # Update every 10 batches
|
||||
process_position=0
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[progress_bar])
|
||||
```
|
||||
|
||||
### 5. GradientAccumulationScheduler
|
||||
|
||||
**Dynamic gradient accumulation**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import GradientAccumulationScheduler
|
||||
|
||||
# Accumulate more gradients as training progresses
|
||||
accumulator = GradientAccumulationScheduler(
|
||||
scheduling={
|
||||
0: 8, # Epochs 0-4: accumulate 8 batches
|
||||
5: 4, # Epochs 5-9: accumulate 4 batches
|
||||
10: 2 # Epochs 10+: accumulate 2 batches
|
||||
}
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[accumulator])
|
||||
```
|
||||
|
||||
### 6. StochasticWeightAveraging (SWA)
|
||||
|
||||
**Averages weights for better generalization**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import StochasticWeightAveraging
|
||||
|
||||
swa = StochasticWeightAveraging(
|
||||
swa_lrs=1e-2, # SWA learning rate
|
||||
swa_epoch_start=0.8, # Start at 80% of training
|
||||
annealing_epochs=10, # Annealing period
|
||||
annealing_strategy='cos' # 'cos' or 'linear'
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[swa])
|
||||
```
|
||||
|
||||
## Custom Callbacks
|
||||
|
||||
### Basic Custom Callback
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
|
||||
class PrintingCallback(Callback):
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
print("Training is starting!")
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
print("Training is done!")
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
print(f"Epoch {trainer.current_epoch} ended")
|
||||
|
||||
# Use it
|
||||
trainer = L.Trainer(callbacks=[PrintingCallback()])
|
||||
```
|
||||
|
||||
### Advanced Custom Callback
|
||||
|
||||
```python
|
||||
class MetricsCallback(Callback):
|
||||
"""Logs custom metrics every N batches."""
|
||||
|
||||
def __init__(self, log_every_n_batches=100):
|
||||
self.log_every_n_batches = log_every_n_batches
|
||||
self.metrics = []
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
if batch_idx % self.log_every_n_batches == 0:
|
||||
# Compute custom metric
|
||||
metric = self.compute_metric(outputs)
|
||||
self.metrics.append(metric)
|
||||
|
||||
# Log to Lightning
|
||||
pl_module.log('custom_metric', metric)
|
||||
|
||||
def compute_metric(self, outputs):
|
||||
# Your custom logic
|
||||
return outputs['loss'].item()
|
||||
|
||||
def state_dict(self):
|
||||
"""Save callback state in checkpoint."""
|
||||
return {'metrics': self.metrics}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Restore callback state from checkpoint."""
|
||||
self.metrics = state_dict['metrics']
|
||||
```
|
||||
|
||||
### Gradient Monitoring Callback
|
||||
|
||||
```python
|
||||
class GradientMonitorCallback(Callback):
|
||||
"""Monitor gradient norms."""
|
||||
|
||||
def on_after_backward(self, trainer, pl_module):
|
||||
# Compute gradient norm
|
||||
total_norm = 0.0
|
||||
for p in pl_module.parameters():
|
||||
if p.grad is not None:
|
||||
param_norm = p.grad.data.norm(2)
|
||||
total_norm += param_norm.item() ** 2
|
||||
total_norm = total_norm ** 0.5
|
||||
|
||||
# Log
|
||||
pl_module.log('grad_norm', total_norm)
|
||||
|
||||
# Warn if exploding
|
||||
if total_norm > 100:
|
||||
print(f"Warning: Large gradient norm: {total_norm:.2f}")
|
||||
```
|
||||
|
||||
### Model Inspection Callback
|
||||
|
||||
```python
|
||||
class ModelInspectionCallback(Callback):
|
||||
"""Inspect model activations during training."""
|
||||
|
||||
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
||||
if batch_idx == 0: # First batch of epoch
|
||||
# Register hooks
|
||||
self.activations = {}
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
self.activations[name] = output.detach()
|
||||
return hook
|
||||
|
||||
# Attach to specific layers
|
||||
pl_module.model.layer1.register_forward_hook(get_activation('layer1'))
|
||||
pl_module.model.layer2.register_forward_hook(get_activation('layer2'))
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
if batch_idx == 0:
|
||||
# Log activation statistics
|
||||
for name, activation in self.activations.items():
|
||||
mean = activation.mean().item()
|
||||
std = activation.std().item()
|
||||
pl_module.log(f'{name}_mean', mean)
|
||||
pl_module.log(f'{name}_std', std)
|
||||
```
|
||||
|
||||
## Callback Hooks
|
||||
|
||||
**All available hooks**:
|
||||
|
||||
```python
|
||||
class MyCallback(Callback):
|
||||
# Setup/Teardown
|
||||
def setup(self, trainer, pl_module, stage):
|
||||
"""Called at beginning of fit/test/predict."""
|
||||
pass
|
||||
|
||||
def teardown(self, trainer, pl_module, stage):
|
||||
"""Called at end of fit/test/predict."""
|
||||
pass
|
||||
|
||||
# Training
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
||||
pass
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
pass
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
# Validation
|
||||
def on_validation_start(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_validation_epoch_start(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
|
||||
pass
|
||||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
pass
|
||||
|
||||
def on_validation_epoch_end(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
# Test (same structure as validation)
|
||||
def on_test_start(self, trainer, pl_module):
|
||||
pass
|
||||
# ... (test_epoch_start, test_batch_start, etc.)
|
||||
|
||||
# Predict
|
||||
def on_predict_start(self, trainer, pl_module):
|
||||
pass
|
||||
# ... (predict_epoch_start, predict_batch_start, etc.)
|
||||
|
||||
# Backward
|
||||
def on_before_backward(self, trainer, pl_module, loss):
|
||||
pass
|
||||
|
||||
def on_after_backward(self, trainer, pl_module):
|
||||
pass
|
||||
|
||||
# Optimizer
|
||||
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
|
||||
pass
|
||||
|
||||
# Checkpointing
|
||||
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
|
||||
"""Add data to checkpoint."""
|
||||
pass
|
||||
|
||||
def on_load_checkpoint(self, trainer, pl_module, checkpoint):
|
||||
"""Restore data from checkpoint."""
|
||||
pass
|
||||
```
|
||||
|
||||
## Combining Multiple Callbacks
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
|
||||
|
||||
# Create all callbacks
|
||||
checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3)
|
||||
early_stop = EarlyStopping(monitor='val_loss', patience=5)
|
||||
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||
custom_callback = MyCustomCallback()
|
||||
|
||||
# Add all to Trainer
|
||||
trainer = L.Trainer(
|
||||
callbacks=[checkpoint, early_stop, lr_monitor, custom_callback]
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
```
|
||||
|
||||
**Execution order**: Callbacks execute in the order they're added
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Keep Callbacks Independent
|
||||
|
||||
**Bad** (dependent on other callback):
|
||||
```python
|
||||
class BadCallback(Callback):
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
# Assumes ModelCheckpoint is present
|
||||
best_path = trainer.checkpoint_callback.best_model_path # Fragile!
|
||||
```
|
||||
|
||||
**Good** (self-contained):
|
||||
```python
|
||||
class GoodCallback(Callback):
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
# Find checkpoint callback if present
|
||||
for callback in trainer.callbacks:
|
||||
if isinstance(callback, ModelCheckpoint):
|
||||
best_path = callback.best_model_path
|
||||
break
|
||||
```
|
||||
|
||||
### 2. Use State Dict for Persistence
|
||||
|
||||
```python
|
||||
class StatefulCallback(Callback):
|
||||
def __init__(self):
|
||||
self.counter = 0
|
||||
self.history = []
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
self.counter += 1
|
||||
self.history.append(outputs['loss'].item())
|
||||
|
||||
def state_dict(self):
|
||||
"""Save state."""
|
||||
return {
|
||||
'counter': self.counter,
|
||||
'history': self.history
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Restore state."""
|
||||
self.counter = state_dict['counter']
|
||||
self.history = state_dict['history']
|
||||
```
|
||||
|
||||
### 3. Handle Distributed Training
|
||||
|
||||
```python
|
||||
class DistributedCallback(Callback):
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
# Only run on main process
|
||||
if trainer.is_global_zero:
|
||||
print("This only prints once in distributed training")
|
||||
|
||||
# Run on all processes
|
||||
loss = outputs['loss']
|
||||
# ... do something with loss on each GPU
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Callback API: https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html
|
||||
- Built-in callbacks: https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks
|
||||
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/callbacks
|
||||
490
skills/mlops/pytorch-lightning/references/distributed.md
Normal file
490
skills/mlops/pytorch-lightning/references/distributed.md
Normal file
|
|
@ -0,0 +1,490 @@
|
|||
# PyTorch Lightning Distributed Training
|
||||
|
||||
## Distributed Strategies
|
||||
|
||||
Lightning supports multiple distributed strategies with a single parameter change.
|
||||
|
||||
### 1. DDP (DistributedDataParallel)
|
||||
|
||||
**Default strategy for multi-GPU**:
|
||||
|
||||
```python
|
||||
# Automatic DDP on all available GPUs
|
||||
trainer = L.Trainer(accelerator='gpu', devices=4, strategy='ddp')
|
||||
|
||||
# Or auto-detect
|
||||
trainer = L.Trainer(accelerator='gpu', devices='auto')
|
||||
```
|
||||
|
||||
**How DDP works**:
|
||||
- Replicates model on each GPU
|
||||
- Each GPU processes different batch
|
||||
- Gradients all-reduced across GPUs
|
||||
- Model weights synchronized
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
# Lightning handles spawning processes automatically
|
||||
python train.py
|
||||
```
|
||||
|
||||
**DDP Configuration**:
|
||||
```python
|
||||
from lightning.pytorch.strategies import DDPStrategy
|
||||
|
||||
strategy = DDPStrategy(
|
||||
find_unused_parameters=False, # Set True if model has unused params
|
||||
gradient_as_bucket_view=True, # Memory optimization
|
||||
static_graph=False, # Set True if graph doesn't change
|
||||
)
|
||||
|
||||
trainer = L.Trainer(strategy=strategy)
|
||||
```
|
||||
|
||||
### 2. FSDP (Fully Sharded Data Parallel)
|
||||
|
||||
**For large models (7B+ parameters)**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.strategies import FSDPStrategy
|
||||
|
||||
strategy = FSDPStrategy(
|
||||
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
|
||||
activation_checkpointing=None, # Or specify layer types
|
||||
cpu_offload=False, # CPU offload for memory
|
||||
)
|
||||
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8,
|
||||
strategy=strategy,
|
||||
precision='bf16' # Recommended with FSDP
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**FSDP Sharding Strategies**:
|
||||
```python
|
||||
# FULL_SHARD (most memory efficient, equivalent to ZeRO-3)
|
||||
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
|
||||
|
||||
# SHARD_GRAD_OP (less memory efficient, equivalent to ZeRO-2)
|
||||
strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP")
|
||||
|
||||
# NO_SHARD (no sharding, like DDP)
|
||||
strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
|
||||
```
|
||||
|
||||
**Auto-wrap policy** (wrap transformer blocks):
|
||||
```python
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
|
||||
import functools
|
||||
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={GPT2Block}
|
||||
)
|
||||
|
||||
strategy = FSDPStrategy(
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
activation_checkpointing_policy={GPT2Block} # Checkpoint these blocks
|
||||
)
|
||||
```
|
||||
|
||||
### 3. DeepSpeed
|
||||
|
||||
**For massive models (70B+ parameters)**:
|
||||
|
||||
```python
|
||||
from lightning.pytorch.strategies import DeepSpeedStrategy
|
||||
|
||||
# DeepSpeed ZeRO-3 with CPU offload
|
||||
strategy = DeepSpeedStrategy(
|
||||
stage=3, # ZeRO-3
|
||||
offload_optimizer=True, # CPU offload optimizer
|
||||
offload_parameters=True, # CPU offload parameters
|
||||
cpu_checkpointing=True, # Checkpoint to CPU
|
||||
)
|
||||
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8,
|
||||
strategy=strategy,
|
||||
precision='bf16'
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**DeepSpeed configuration file**:
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"stage3_prefetch_bucket_size": 5e8,
|
||||
"stage3_param_persistence_threshold": 1e6
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Use config file**:
|
||||
```python
|
||||
strategy = DeepSpeedStrategy(config='deepspeed_config.json')
|
||||
trainer = L.Trainer(strategy=strategy)
|
||||
```
|
||||
|
||||
### 4. DDP Spawn
|
||||
|
||||
**Windows-compatible DDP**:
|
||||
|
||||
```python
|
||||
# Use when DDP doesn't work (e.g., Windows, Jupyter)
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=2,
|
||||
strategy='ddp_spawn' # Spawns new processes
|
||||
)
|
||||
```
|
||||
|
||||
**Note**: Slower than DDP due to process spawning overhead
|
||||
|
||||
## Multi-Node Training
|
||||
|
||||
### Setup Multi-Node Cluster
|
||||
|
||||
**Node 0 (master)**:
|
||||
```bash
|
||||
export MASTER_ADDR=192.168.1.100
|
||||
export MASTER_PORT=12355
|
||||
export WORLD_SIZE=16 # 2 nodes × 8 GPUs
|
||||
export NODE_RANK=0
|
||||
|
||||
python train.py
|
||||
```
|
||||
|
||||
**Node 1 (worker)**:
|
||||
```bash
|
||||
export MASTER_ADDR=192.168.1.100
|
||||
export MASTER_PORT=12355
|
||||
export WORLD_SIZE=16
|
||||
export NODE_RANK=1
|
||||
|
||||
python train.py
|
||||
```
|
||||
|
||||
**Training script**:
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8, # GPUs per node
|
||||
num_nodes=2, # Total nodes
|
||||
strategy='ddp'
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
### SLURM Integration
|
||||
|
||||
**SLURM job script**:
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --nodes=4
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --time=24:00:00
|
||||
|
||||
# Lightning auto-detects SLURM environment
|
||||
srun python train.py
|
||||
```
|
||||
|
||||
**Training script** (no changes needed):
|
||||
```python
|
||||
# Lightning automatically reads SLURM environment variables
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=8,
|
||||
num_nodes=4, # From SBATCH --nodes
|
||||
strategy='ddp'
|
||||
)
|
||||
```
|
||||
|
||||
### Kubernetes (KubeFlow)
|
||||
|
||||
**Training script**:
|
||||
```python
|
||||
import os
|
||||
|
||||
# Lightning auto-detects Kubernetes
|
||||
trainer = L.Trainer(
|
||||
accelerator='gpu',
|
||||
devices=int(os.getenv('WORLD_SIZE', 1)),
|
||||
strategy='ddp'
|
||||
)
|
||||
```
|
||||
|
||||
## Mixed Precision Training
|
||||
|
||||
### BF16 (A100/H100)
|
||||
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
precision='bf16', # Or 'bf16-mixed'
|
||||
accelerator='gpu'
|
||||
)
|
||||
```
|
||||
|
||||
**Advantages**:
|
||||
- No gradient scaler needed
|
||||
- Same dynamic range as FP32
|
||||
- 2× speedup, 50% memory reduction
|
||||
|
||||
### FP16 (V100, older GPUs)
|
||||
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
precision='16-mixed', # Or just '16'
|
||||
accelerator='gpu'
|
||||
)
|
||||
```
|
||||
|
||||
**Automatic gradient scaling** handled by Lightning
|
||||
|
||||
### FP8 (H100)
|
||||
|
||||
```python
|
||||
# Requires transformer_engine
|
||||
# pip install transformer-engine[pytorch]
|
||||
|
||||
trainer = L.Trainer(
|
||||
precision='transformer-engine',
|
||||
accelerator='gpu'
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**: 2× faster than BF16 on H100
|
||||
|
||||
## Gradient Accumulation
|
||||
|
||||
**Simulate larger batch size**:
|
||||
|
||||
```python
|
||||
trainer = L.Trainer(
|
||||
accumulate_grad_batches=4, # Accumulate 4 batches
|
||||
precision='bf16'
|
||||
)
|
||||
|
||||
# Effective batch = batch_size × accumulate_grad_batches × num_gpus
|
||||
# Example: 32 × 4 × 8 = 1024
|
||||
```
|
||||
|
||||
**Dynamic accumulation**:
|
||||
```python
|
||||
# Accumulate more early in training
|
||||
trainer = L.Trainer(
|
||||
accumulate_grad_batches={
|
||||
0: 8, # Epochs 0-4: accumulate 8
|
||||
5: 4, # Epochs 5-9: accumulate 4
|
||||
10: 2 # Epochs 10+: accumulate 2
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Checkpointing in Distributed
|
||||
|
||||
### Save Checkpoint
|
||||
|
||||
```python
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
|
||||
# Only rank 0 saves by default
|
||||
checkpoint = ModelCheckpoint(
|
||||
dirpath='checkpoints/',
|
||||
filename='model-{epoch:02d}',
|
||||
save_top_k=3
|
||||
)
|
||||
|
||||
trainer = L.Trainer(callbacks=[checkpoint], strategy='ddp')
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
**Manual save**:
|
||||
```python
|
||||
class MyModel(L.LightningModule):
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Training...
|
||||
loss = ...
|
||||
|
||||
# Save every 1000 steps (only rank 0)
|
||||
if batch_idx % 1000 == 0 and self.trainer.is_global_zero:
|
||||
self.trainer.save_checkpoint(f'checkpoint_step_{batch_idx}.ckpt')
|
||||
|
||||
return loss
|
||||
```
|
||||
|
||||
### Load Checkpoint
|
||||
|
||||
```python
|
||||
# Resume training
|
||||
trainer = L.Trainer(strategy='ddp')
|
||||
trainer.fit(model, train_loader, ckpt_path='checkpoints/last.ckpt')
|
||||
|
||||
# Load for inference
|
||||
model = MyModel.load_from_checkpoint('checkpoints/best.ckpt')
|
||||
model.eval()
|
||||
```
|
||||
|
||||
## Strategy Comparison
|
||||
|
||||
| Strategy | Memory Efficiency | Speed | Use Case |
|
||||
|----------|------------------|-------|----------|
|
||||
| DDP | Low | Fast | Small models (<7B), single node |
|
||||
| FSDP | High | Medium | Large models (7-70B) |
|
||||
| DeepSpeed ZeRO-2 | Medium | Fast | Medium models (1-13B) |
|
||||
| DeepSpeed ZeRO-3 | Very High | Slower | Massive models (70B+) |
|
||||
| DDP Spawn | Low | Slow | Windows, debugging |
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Choose Right Strategy
|
||||
|
||||
```python
|
||||
# Model size guide
|
||||
if model_params < 1e9: # <1B
|
||||
strategy = 'ddp'
|
||||
elif model_params < 7e9: # 1-7B
|
||||
strategy = 'ddp' or DeepSpeedStrategy(stage=2)
|
||||
elif model_params < 70e9: # 7-70B
|
||||
strategy = FSDPStrategy(sharding_strategy="FULL_SHARD")
|
||||
else: # 70B+
|
||||
strategy = DeepSpeedStrategy(stage=3, offload_optimizer=True)
|
||||
|
||||
trainer = L.Trainer(strategy=strategy)
|
||||
```
|
||||
|
||||
### 2. Avoid Sync Issues
|
||||
|
||||
```python
|
||||
class MyModel(L.LightningModule):
|
||||
def training_step(self, batch, batch_idx):
|
||||
# WRONG: This runs on all GPUs independently
|
||||
if batch_idx % 100 == 0:
|
||||
self.log_something() # Logged 8 times on 8 GPUs!
|
||||
|
||||
# CORRECT: Use is_global_zero
|
||||
if batch_idx % 100 == 0 and self.trainer.is_global_zero:
|
||||
self.log_something() # Logged once
|
||||
|
||||
loss = ...
|
||||
return loss
|
||||
```
|
||||
|
||||
### 3. Efficient Data Loading
|
||||
|
||||
```python
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
|
||||
# Lightning handles DistributedSampler automatically
|
||||
train_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # 4 workers per GPU
|
||||
pin_memory=True,
|
||||
persistent_workers=True
|
||||
)
|
||||
|
||||
# Lightning automatically wraps with DistributedSampler in DDP
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
### 4. Reduce Communication Overhead
|
||||
|
||||
```python
|
||||
from lightning.pytorch.strategies import DDPStrategy
|
||||
|
||||
strategy = DDPStrategy(
|
||||
gradient_as_bucket_view=True, # Reduce memory copies
|
||||
static_graph=True, # If model graph doesn't change (faster)
|
||||
)
|
||||
|
||||
trainer = L.Trainer(strategy=strategy)
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: NCCL Timeout
|
||||
|
||||
**Symptom**: Training hangs with `NCCL timeout` error
|
||||
|
||||
**Solution 1**: Increase timeout
|
||||
```bash
|
||||
export NCCL_TIMEOUT=3600 # 1 hour
|
||||
python train.py
|
||||
```
|
||||
|
||||
**Solution 2**: Check network
|
||||
```bash
|
||||
# Test inter-node communication
|
||||
nvidia-smi nvlink -s
|
||||
|
||||
# Verify all nodes can ping each other
|
||||
ping <node-2-ip>
|
||||
```
|
||||
|
||||
### Issue: OOM with FSDP
|
||||
|
||||
**Solution**: Enable CPU offload
|
||||
```python
|
||||
strategy = FSDPStrategy(
|
||||
sharding_strategy="FULL_SHARD",
|
||||
cpu_offload=True # Offload to CPU
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Different Results with DDP
|
||||
|
||||
**Cause**: Different random seeds per GPU
|
||||
|
||||
**Solution**: Set seed in LightningModule
|
||||
```python
|
||||
class MyModel(L.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
L.seed_everything(42, workers=True) # Same seed everywhere
|
||||
```
|
||||
|
||||
### Issue: DeepSpeed Config Errors
|
||||
|
||||
**Solution**: Use Lightning's auto config
|
||||
```python
|
||||
strategy = DeepSpeedStrategy(
|
||||
stage=3,
|
||||
# Don't specify config file, Lightning generates automatically
|
||||
)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Distributed strategies: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html
|
||||
- FSDP guide: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html
|
||||
- DeepSpeed: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/deepspeed.html
|
||||
- Multi-node: https://lightning.ai/docs/pytorch/stable/clouds/cluster.html
|
||||
|
|
@ -0,0 +1,556 @@
|
|||
# Hyperparameter Tuning with PyTorch Lightning
|
||||
|
||||
## Integration with Tuning Frameworks
|
||||
|
||||
Lightning integrates seamlessly with popular hyperparameter tuning libraries.
|
||||
|
||||
### 1. Ray Tune Integration
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install ray[tune]
|
||||
pip install lightning
|
||||
```
|
||||
|
||||
**Basic Ray Tune example**:
|
||||
|
||||
```python
|
||||
import lightning as L
|
||||
from ray import tune
|
||||
from ray.tune.integration.pytorch_lightning import TuneReportCallback
|
||||
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self, lr, batch_size):
|
||||
super().__init__()
|
||||
self.lr = lr
|
||||
self.batch_size = batch_size
|
||||
self.model = nn.Sequential(nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, 1))
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.model(batch).mean()
|
||||
self.log('train_loss', loss)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
val_loss = self.model(batch).mean()
|
||||
self.log('val_loss', val_loss)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
|
||||
def train_fn(config):
|
||||
"""Training function for Ray Tune."""
|
||||
model = LitModel(lr=config["lr"], batch_size=config["batch_size"])
|
||||
|
||||
# Add callback to report metrics to Tune
|
||||
trainer = L.Trainer(
|
||||
max_epochs=10,
|
||||
callbacks=[TuneReportCallback({"loss": "val_loss"}, on="validation_end")]
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
# Define search space
|
||||
config = {
|
||||
"lr": tune.loguniform(1e-5, 1e-1),
|
||||
"batch_size": tune.choice([16, 32, 64, 128])
|
||||
}
|
||||
|
||||
# Run hyperparameter search
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=20, # 20 trials
|
||||
resources_per_trial={"gpu": 1}
|
||||
)
|
||||
|
||||
# Best hyperparameters
|
||||
best_config = analysis.get_best_config(metric="loss", mode="min")
|
||||
print(f"Best config: {best_config}")
|
||||
```
|
||||
|
||||
**Advanced: Population-Based Training (PBT)**:
|
||||
|
||||
```python
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
# PBT scheduler
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr='training_iteration',
|
||||
metric='val_loss',
|
||||
mode='min',
|
||||
perturbation_interval=5, # Perturb every 5 epochs
|
||||
hyperparam_mutations={
|
||||
"lr": tune.loguniform(1e-5, 1e-1),
|
||||
"batch_size": [16, 32, 64, 128]
|
||||
}
|
||||
)
|
||||
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=8, # Population size
|
||||
scheduler=scheduler,
|
||||
resources_per_trial={"gpu": 1}
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Optuna Integration
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install optuna
|
||||
pip install optuna-integration
|
||||
```
|
||||
|
||||
**Optuna example**:
|
||||
|
||||
```python
|
||||
import optuna
|
||||
from optuna.integration import PyTorchLightningPruningCallback
|
||||
|
||||
def objective(trial):
|
||||
# Suggest hyperparameters
|
||||
lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
|
||||
batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
|
||||
n_layers = trial.suggest_int('n_layers', 1, 3)
|
||||
hidden_size = trial.suggest_int('hidden_size', 64, 512, step=64)
|
||||
|
||||
# Create model
|
||||
model = LitModel(lr=lr, n_layers=n_layers, hidden_size=hidden_size)
|
||||
|
||||
# Pruning callback (early stopping for bad trials)
|
||||
pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=20,
|
||||
callbacks=[pruning_callback],
|
||||
enable_progress_bar=False,
|
||||
logger=False
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
return trainer.callback_metrics["val_loss"].item()
|
||||
|
||||
# Create study
|
||||
study = optuna.create_study(
|
||||
direction='minimize',
|
||||
pruner=optuna.pruners.MedianPruner() # Prune bad trials early
|
||||
)
|
||||
|
||||
# Optimize
|
||||
study.optimize(objective, n_trials=50, timeout=3600)
|
||||
|
||||
# Best params
|
||||
print(f"Best trial: {study.best_trial.params}")
|
||||
print(f"Best value: {study.best_value}")
|
||||
|
||||
# Visualization
|
||||
optuna.visualization.plot_optimization_history(study).show()
|
||||
optuna.visualization.plot_param_importances(study).show()
|
||||
```
|
||||
|
||||
**Optuna with distributed training**:
|
||||
|
||||
```python
|
||||
import optuna
|
||||
|
||||
# Shared database for distributed optimization
|
||||
storage = optuna.storages.RDBStorage(
|
||||
url='postgresql://user:pass@localhost/optuna'
|
||||
)
|
||||
|
||||
study = optuna.create_study(
|
||||
study_name='distributed_study',
|
||||
storage=storage,
|
||||
load_if_exists=True,
|
||||
direction='minimize'
|
||||
)
|
||||
|
||||
# Run on multiple machines
|
||||
study.optimize(objective, n_trials=50)
|
||||
```
|
||||
|
||||
### 3. Weights & Biases (WandB) Sweeps
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install wandb
|
||||
```
|
||||
|
||||
**WandB sweep config** (`sweep.yaml`):
|
||||
```yaml
|
||||
program: train.py
|
||||
method: bayes
|
||||
metric:
|
||||
name: val_loss
|
||||
goal: minimize
|
||||
parameters:
|
||||
lr:
|
||||
distribution: log_uniform_values
|
||||
min: 0.00001
|
||||
max: 0.1
|
||||
batch_size:
|
||||
values: [16, 32, 64, 128]
|
||||
optimizer:
|
||||
values: ['adam', 'sgd', 'adamw']
|
||||
dropout:
|
||||
distribution: uniform
|
||||
min: 0.0
|
||||
max: 0.5
|
||||
```
|
||||
|
||||
**Training script** (`train.py`):
|
||||
```python
|
||||
import wandb
|
||||
import lightning as L
|
||||
from lightning.pytorch.loggers import WandbLogger
|
||||
|
||||
def train():
|
||||
# Initialize wandb
|
||||
wandb.init()
|
||||
config = wandb.config
|
||||
|
||||
# Create model with sweep params
|
||||
model = LitModel(
|
||||
lr=config.lr,
|
||||
batch_size=config.batch_size,
|
||||
optimizer=config.optimizer,
|
||||
dropout=config.dropout
|
||||
)
|
||||
|
||||
# WandB logger
|
||||
wandb_logger = WandbLogger(project='hyperparameter-sweep')
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=20,
|
||||
logger=wandb_logger
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
```
|
||||
|
||||
**Launch sweep**:
|
||||
```bash
|
||||
# Initialize sweep
|
||||
wandb sweep sweep.yaml
|
||||
# Output: wandb: Created sweep with ID: abc123
|
||||
|
||||
# Run agent (can run on multiple machines)
|
||||
wandb agent your-entity/your-project/abc123
|
||||
```
|
||||
|
||||
### 4. Hyperopt Integration
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install hyperopt
|
||||
```
|
||||
|
||||
**Hyperopt example**:
|
||||
|
||||
```python
|
||||
from hyperopt import hp, fmin, tpe, Trials
|
||||
|
||||
def objective(params):
|
||||
model = LitModel(
|
||||
lr=params['lr'],
|
||||
batch_size=int(params['batch_size']),
|
||||
hidden_size=int(params['hidden_size'])
|
||||
)
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=10,
|
||||
enable_progress_bar=False,
|
||||
logger=False
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
|
||||
# Return loss (minimize)
|
||||
return trainer.callback_metrics["val_loss"].item()
|
||||
|
||||
# Define search space
|
||||
space = {
|
||||
'lr': hp.loguniform('lr', np.log(1e-5), np.log(1e-1)),
|
||||
'batch_size': hp.quniform('batch_size', 16, 128, 16),
|
||||
'hidden_size': hp.quniform('hidden_size', 64, 512, 64)
|
||||
}
|
||||
|
||||
# Optimize
|
||||
trials = Trials()
|
||||
best = fmin(
|
||||
fn=objective,
|
||||
space=space,
|
||||
algo=tpe.suggest, # Tree-structured Parzen Estimator
|
||||
max_evals=50,
|
||||
trials=trials
|
||||
)
|
||||
|
||||
print(f"Best hyperparameters: {best}")
|
||||
```
|
||||
|
||||
## Built-In Lightning Tuning
|
||||
|
||||
### Auto Learning Rate Finder
|
||||
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self, lr=1e-3):
|
||||
super().__init__()
|
||||
self.lr = lr
|
||||
self.model = nn.Linear(10, 1)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss = self.model(batch).mean()
|
||||
return loss
|
||||
|
||||
# Find optimal learning rate
|
||||
model = LitModel()
|
||||
trainer = L.Trainer(auto_lr_find=True)
|
||||
|
||||
# This runs LR finder before training
|
||||
trainer.tune(model, train_loader)
|
||||
|
||||
# Or manually
|
||||
from lightning.pytorch.tuner import Tuner
|
||||
tuner = Tuner(trainer)
|
||||
lr_finder = tuner.lr_find(model, train_loader)
|
||||
|
||||
# Plot results
|
||||
fig = lr_finder.plot(suggest=True)
|
||||
fig.show()
|
||||
|
||||
# Get suggested LR
|
||||
suggested_lr = lr_finder.suggestion()
|
||||
print(f"Suggested LR: {suggested_lr}")
|
||||
|
||||
# Update model
|
||||
model.lr = suggested_lr
|
||||
|
||||
# Train with optimal LR
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
### Auto Batch Size Finder
|
||||
|
||||
```python
|
||||
class LitModel(L.LightningModule):
|
||||
def __init__(self, batch_size=32):
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
self.model = nn.Linear(10, 1)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset, batch_size=self.batch_size)
|
||||
|
||||
model = LitModel()
|
||||
trainer = L.Trainer(auto_scale_batch_size='binsearch')
|
||||
|
||||
# Find optimal batch size
|
||||
trainer.tune(model)
|
||||
|
||||
print(f"Optimal batch size: {model.batch_size}")
|
||||
|
||||
# Train with optimal batch size
|
||||
trainer.fit(model, train_loader)
|
||||
```
|
||||
|
||||
## Advanced Tuning Strategies
|
||||
|
||||
### 1. Multi-Fidelity Optimization (Successive Halving)
|
||||
|
||||
```python
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
|
||||
# ASHA: Asynchronous Successive Halving Algorithm
|
||||
scheduler = ASHAScheduler(
|
||||
max_t=100, # Max epochs
|
||||
grace_period=10, # Min epochs before stopping
|
||||
reduction_factor=2 # Halve resources each round
|
||||
)
|
||||
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=64,
|
||||
scheduler=scheduler,
|
||||
resources_per_trial={"gpu": 1}
|
||||
)
|
||||
```
|
||||
|
||||
**How it works**:
|
||||
- Start 64 trials
|
||||
- After 10 epochs, stop bottom 50% (32 trials remain)
|
||||
- After 20 epochs, stop bottom 50% (16 trials remain)
|
||||
- After 40 epochs, stop bottom 50% (8 trials remain)
|
||||
- After 80 epochs, stop bottom 50% (4 trials remain)
|
||||
- Run remaining 4 trials to completion (100 epochs)
|
||||
|
||||
### 2. Bayesian Optimization
|
||||
|
||||
```python
|
||||
from ray.tune.search.bayesopt import BayesOptSearch
|
||||
|
||||
search = BayesOptSearch(
|
||||
metric="val_loss",
|
||||
mode="min"
|
||||
)
|
||||
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=50,
|
||||
search_alg=search,
|
||||
resources_per_trial={"gpu": 1}
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Grid Search
|
||||
|
||||
```python
|
||||
from ray import tune
|
||||
|
||||
# Exhaustive grid search
|
||||
config = {
|
||||
"lr": tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2]),
|
||||
"batch_size": tune.grid_search([16, 32, 64, 128]),
|
||||
"optimizer": tune.grid_search(['adam', 'sgd', 'adamw'])
|
||||
}
|
||||
|
||||
# Total trials: 4 × 4 × 3 = 48
|
||||
analysis = tune.run(train_fn, config=config)
|
||||
```
|
||||
|
||||
### 4. Random Search
|
||||
|
||||
```python
|
||||
config = {
|
||||
"lr": tune.loguniform(1e-5, 1e-1),
|
||||
"batch_size": tune.choice([16, 32, 64, 128]),
|
||||
"dropout": tune.uniform(0.0, 0.5),
|
||||
"hidden_size": tune.randint(64, 512)
|
||||
}
|
||||
|
||||
# Random sampling
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
num_samples=100 # 100 random samples
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Start Simple
|
||||
|
||||
```python
|
||||
# Phase 1: Coarse search (fast)
|
||||
coarse_config = {
|
||||
"lr": tune.loguniform(1e-5, 1e-1),
|
||||
"batch_size": tune.choice([32, 64])
|
||||
}
|
||||
coarse_analysis = tune.run(train_fn, config=coarse_config, num_samples=10, max_epochs=5)
|
||||
|
||||
# Phase 2: Fine-tune around best (slow)
|
||||
best_lr = coarse_analysis.best_config["lr"]
|
||||
fine_config = {
|
||||
"lr": tune.uniform(best_lr * 0.5, best_lr * 2),
|
||||
"batch_size": tune.choice([16, 32, 64, 128])
|
||||
}
|
||||
fine_analysis = tune.run(train_fn, config=fine_config, num_samples=20, max_epochs=20)
|
||||
```
|
||||
|
||||
### 2. Use Checkpointing
|
||||
|
||||
```python
|
||||
def train_fn(config, checkpoint_dir=None):
|
||||
model = LitModel(lr=config["lr"])
|
||||
|
||||
trainer = L.Trainer(
|
||||
max_epochs=100,
|
||||
callbacks=[
|
||||
TuneReportCheckpointCallback(
|
||||
metrics={"loss": "val_loss"},
|
||||
filename="checkpoint",
|
||||
on="validation_end"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Resume from checkpoint if exists
|
||||
ckpt_path = None
|
||||
if checkpoint_dir:
|
||||
ckpt_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
|
||||
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
|
||||
```
|
||||
|
||||
### 3. Monitor Resource Usage
|
||||
|
||||
```python
|
||||
import GPUtil
|
||||
|
||||
def train_fn(config):
|
||||
# Before training
|
||||
GPUs = GPUtil.getGPUs()
|
||||
print(f"GPU memory before: {GPUs[0].memoryUsed} MB")
|
||||
|
||||
# Train
|
||||
model = LitModel(lr=config["lr"], batch_size=config["batch_size"])
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# After training
|
||||
GPUs = GPUtil.getGPUs()
|
||||
print(f"GPU memory after: {GPUs[0].memoryUsed} MB")
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: Trials Running Out of Memory
|
||||
|
||||
**Solution**: Reduce concurrent trials or batch size
|
||||
```python
|
||||
analysis = tune.run(
|
||||
train_fn,
|
||||
config=config,
|
||||
resources_per_trial={"gpu": 0.5}, # 2 trials per GPU
|
||||
max_concurrent_trials=2 # Limit concurrent trials
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slow Hyperparameter Search
|
||||
|
||||
**Solution**: Use early stopping scheduler
|
||||
```python
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
|
||||
scheduler = ASHAScheduler(
|
||||
max_t=100,
|
||||
grace_period=5, # Stop bad trials after 5 epochs
|
||||
reduction_factor=3
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Can't Reproduce Best Trial
|
||||
|
||||
**Solution**: Set seeds in training function
|
||||
```python
|
||||
def train_fn(config):
|
||||
L.seed_everything(42, workers=True)
|
||||
# Rest of training...
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Ray Tune + Lightning: https://docs.ray.io/en/latest/tune/examples/tune-pytorch-lightning.html
|
||||
- Optuna: https://optuna.readthedocs.io/
|
||||
- WandB Sweeps: https://docs.wandb.ai/guides/sweeps
|
||||
- Lightning Tuner: https://lightning.ai/docs/pytorch/stable/tuning.html
|
||||
222
skills/mlops/simpo/SKILL.md
Normal file
222
skills/mlops/simpo/SKILL.md
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
---
|
||||
name: simpo-training
|
||||
description: Simple Preference Optimization for LLM alignment. Reference-free alternative to DPO with better performance (+6.4 points on AlpacaEval 2.0). No reference model needed, more efficient than DPO. Use for preference alignment when want simpler, faster training than DPO/PPO.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [torch, transformers, datasets, trl, accelerate]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, SimPO, Preference Optimization, Alignment, DPO Alternative, Reference-Free, LLM Alignment, Efficient Training]
|
||||
|
||||
---
|
||||
|
||||
# SimPO - Simple Preference Optimization
|
||||
|
||||
## Quick start
|
||||
|
||||
SimPO is a reference-free preference optimization method that outperforms DPO without needing a reference model.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
# Create environment
|
||||
conda create -n simpo python=3.10 && conda activate simpo
|
||||
|
||||
# Install PyTorch 2.2.2
|
||||
# Visit: https://pytorch.org/get-started/locally/
|
||||
|
||||
# Install alignment-handbook
|
||||
git clone https://github.com/huggingface/alignment-handbook.git
|
||||
cd alignment-handbook
|
||||
python -m pip install .
|
||||
|
||||
# Install Flash Attention 2
|
||||
python -m pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
**Training** (Mistral 7B):
|
||||
```bash
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch \
|
||||
--config_file accelerate_configs/deepspeed_zero3.yaml \
|
||||
scripts/run_simpo.py \
|
||||
training_configs/mistral-7b-base-simpo.yaml
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Train from base model (Mistral 7B)
|
||||
|
||||
**Config** (`mistral-7b-base-simpo.yaml`):
|
||||
```yaml
|
||||
# Model
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
torch_dtype: bfloat16
|
||||
|
||||
# Dataset
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
dataset_splits:
|
||||
- train_prefs
|
||||
- test_prefs
|
||||
|
||||
# SimPO hyperparameters
|
||||
beta: 2.0 # Reward scaling (2.0-10.0)
|
||||
gamma_beta_ratio: 0.5 # Target margin (0-1)
|
||||
loss_type: sigmoid # sigmoid or hinge
|
||||
sft_weight: 0.0 # Optional SFT regularization
|
||||
|
||||
# Training
|
||||
learning_rate: 5e-7 # Critical: 3e-7 to 1e-6
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
|
||||
# Output
|
||||
output_dir: ./outputs/mistral-7b-simpo
|
||||
```
|
||||
|
||||
**Launch training**:
|
||||
```bash
|
||||
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
|
||||
scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml
|
||||
```
|
||||
|
||||
### Workflow 2: Fine-tune instruct model (Llama 3 8B)
|
||||
|
||||
**Config** (`llama3-8b-instruct-simpo.yaml`):
|
||||
```yaml
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
dataset_mixer:
|
||||
argilla/ultrafeedback-binarized-preferences-cleaned: 1.0
|
||||
|
||||
beta: 2.5
|
||||
gamma_beta_ratio: 0.5
|
||||
learning_rate: 5e-7
|
||||
sft_weight: 0.1 # Add SFT loss to preserve capabilities
|
||||
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
output_dir: ./outputs/llama3-8b-simpo
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml \
|
||||
scripts/run_simpo.py training_configs/llama3-8b-instruct-simpo.yaml
|
||||
```
|
||||
|
||||
### Workflow 3: Reasoning-intensive tasks (lower LR)
|
||||
|
||||
**For math/code tasks**:
|
||||
```yaml
|
||||
model_name_or_path: deepseek-ai/deepseek-math-7b-base
|
||||
|
||||
dataset_mixer:
|
||||
argilla/distilabel-math-preference-dpo: 1.0
|
||||
|
||||
beta: 5.0 # Higher for stronger signal
|
||||
gamma_beta_ratio: 0.7 # Larger margin
|
||||
learning_rate: 3e-7 # Lower LR for reasoning
|
||||
sft_weight: 0.0
|
||||
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 16
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use SimPO when**:
|
||||
- Want simpler training than DPO (no reference model)
|
||||
- Have preference data (chosen/rejected pairs)
|
||||
- Need better performance than DPO
|
||||
- Limited compute resources
|
||||
- Single-node training sufficient
|
||||
|
||||
**Algorithm selection**:
|
||||
- **SimPO**: Simplest, best performance, no reference model
|
||||
- **DPO**: Need reference model baseline, more conservative
|
||||
- **PPO**: Maximum control, need reward model, complex setup
|
||||
- **GRPO**: Memory-efficient RL, no critic
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **OpenRLHF**: Multi-node distributed training, PPO/GRPO
|
||||
- **TRL**: Need multiple methods in one framework
|
||||
- **DPO**: Established baseline comparison
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Loss divergence**
|
||||
|
||||
Reduce learning rate:
|
||||
```yaml
|
||||
learning_rate: 3e-7 # Reduce from 5e-7
|
||||
```
|
||||
|
||||
Reduce beta:
|
||||
```yaml
|
||||
beta: 1.0 # Reduce from 2.0
|
||||
```
|
||||
|
||||
**Issue: Model forgets capabilities**
|
||||
|
||||
Add SFT regularization:
|
||||
```yaml
|
||||
sft_weight: 0.1 # Add SFT loss component
|
||||
```
|
||||
|
||||
**Issue: Poor preference separation**
|
||||
|
||||
Increase beta and margin:
|
||||
```yaml
|
||||
beta: 5.0 # Increase from 2.0
|
||||
gamma_beta_ratio: 0.8 # Increase from 0.5
|
||||
```
|
||||
|
||||
**Issue: OOM during training**
|
||||
|
||||
Reduce batch size:
|
||||
```yaml
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 16 # Maintain effective batch
|
||||
```
|
||||
|
||||
Enable gradient checkpointing:
|
||||
```yaml
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Loss functions**: See [references/loss-functions.md](references/loss-functions.md) for sigmoid vs hinge loss, mathematical formulations, and when to use each.
|
||||
|
||||
**Hyperparameter tuning**: See [references/hyperparameters.md](references/hyperparameters.md) for beta, gamma, learning rate selection guide, and model-size-specific recommendations.
|
||||
|
||||
**Dataset preparation**: See [references/datasets.md](references/datasets.md) for preference data formats, quality filtering, and custom dataset creation.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA A100/H100 recommended
|
||||
- **VRAM**:
|
||||
- 7B model: 1× A100 40GB (DeepSpeed ZeRO-3)
|
||||
- 8B model: 2× A100 40GB
|
||||
- 70B model: 8× A100 80GB
|
||||
- **Single-node**: DeepSpeed ZeRO-3 sufficient
|
||||
- **Mixed precision**: BF16 recommended
|
||||
|
||||
**Memory optimization**:
|
||||
- DeepSpeed ZeRO-3 (default config)
|
||||
- Gradient checkpointing
|
||||
- Flash Attention 2
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper: https://arxiv.org/abs/2405.14734 (NeurIPS 2024)
|
||||
- GitHub: https://github.com/princeton-nlp/SimPO
|
||||
- Models: https://huggingface.co/princeton-nlp
|
||||
- Alignment Handbook: https://github.com/huggingface/alignment-handbook
|
||||
|
||||
|
||||
|
||||
478
skills/mlops/simpo/references/datasets.md
Normal file
478
skills/mlops/simpo/references/datasets.md
Normal file
|
|
@ -0,0 +1,478 @@
|
|||
# Datasets
|
||||
|
||||
Complete guide to preference datasets for SimPO training.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
### Required Fields
|
||||
|
||||
Preference datasets must contain:
|
||||
```json
|
||||
{
|
||||
"prompt": "User question or instruction",
|
||||
"chosen": "Better/preferred response",
|
||||
"rejected": "Worse/rejected response"
|
||||
}
|
||||
```
|
||||
|
||||
**Alternative field names** (auto-detected):
|
||||
- `prompt` → `question`, `instruction`, `input`
|
||||
- `chosen` → `response_chosen`, `winner`, `preferred`
|
||||
- `rejected` → `response_rejected`, `loser`
|
||||
|
||||
### Example Entry
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "Explain quantum computing in simple terms.",
|
||||
"chosen": "Quantum computing uses quantum bits (qubits) that can exist in multiple states simultaneously through superposition. This allows quantum computers to process many possibilities at once, making them potentially much faster than classical computers for specific tasks like cryptography and optimization.",
|
||||
"rejected": "It's like regular computing but quantum."
|
||||
}
|
||||
```
|
||||
|
||||
## Popular Datasets
|
||||
|
||||
### 1. UltraFeedback (Recommended)
|
||||
|
||||
**HuggingFaceH4/ultrafeedback_binarized**:
|
||||
- **Size**: 60K preference pairs
|
||||
- **Quality**: High (GPT-4 annotations)
|
||||
- **Domain**: General instruction following
|
||||
- **Format**: Clean, ready-to-use
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
dataset_splits:
|
||||
- train_prefs
|
||||
- test_prefs
|
||||
```
|
||||
|
||||
### 2. Argilla UltraFeedback (Cleaned)
|
||||
|
||||
**argilla/ultrafeedback-binarized-preferences-cleaned**:
|
||||
- **Size**: 50K pairs (filtered)
|
||||
- **Quality**: Very high (deduped, cleaned)
|
||||
- **Domain**: General
|
||||
- **Format**: Clean
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
argilla/ultrafeedback-binarized-preferences-cleaned: 1.0
|
||||
```
|
||||
|
||||
### 3. Distilabel Math
|
||||
|
||||
**argilla/distilabel-math-preference-dpo**:
|
||||
- **Size**: 30K pairs
|
||||
- **Quality**: High (GSM8K, MATH)
|
||||
- **Domain**: Math reasoning
|
||||
- **Format**: Math-specific
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
argilla/distilabel-math-preference-dpo: 1.0
|
||||
```
|
||||
|
||||
### 4. HelpSteer
|
||||
|
||||
**nvidia/HelpSteer**:
|
||||
- **Size**: 38K samples
|
||||
- **Quality**: High (human ratings)
|
||||
- **Domain**: Helpfulness alignment
|
||||
- **Format**: Multi-attribute ratings
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
nvidia/HelpSteer: 1.0
|
||||
```
|
||||
|
||||
### 5. Anthropic HH-RLHF
|
||||
|
||||
**Anthropic/hh-rlhf**:
|
||||
- **Size**: 161K samples
|
||||
- **Quality**: High (human preferences)
|
||||
- **Domain**: Harmless + helpful
|
||||
- **Format**: Conversational
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
Anthropic/hh-rlhf: 1.0
|
||||
```
|
||||
|
||||
## Dataset Mixing
|
||||
|
||||
### Multiple Datasets
|
||||
|
||||
**Equal mix**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 0.5
|
||||
Anthropic/hh-rlhf: 0.5
|
||||
```
|
||||
|
||||
**Weighted mix**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 0.7
|
||||
argilla/distilabel-math-preference-dpo: 0.2
|
||||
nvidia/HelpSteer: 0.1
|
||||
```
|
||||
|
||||
**Domain-specific emphasis**:
|
||||
```yaml
|
||||
# 80% general + 20% math
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 0.8
|
||||
argilla/distilabel-math-preference-dpo: 0.2
|
||||
```
|
||||
|
||||
## Data Quality
|
||||
|
||||
### Quality Indicators
|
||||
|
||||
**Good preference data**:
|
||||
- ✅ Clear quality difference between chosen/rejected
|
||||
- ✅ Diverse prompts
|
||||
- ✅ Minimal noise/annotation errors
|
||||
- ✅ Appropriate difficulty level
|
||||
|
||||
**Poor preference data**:
|
||||
- ❌ Ambiguous preferences
|
||||
- ❌ Repetitive prompts
|
||||
- ❌ Annotation noise
|
||||
- ❌ Too easy/hard prompts
|
||||
|
||||
### Quality Filtering
|
||||
|
||||
**Filter by length difference**:
|
||||
```python
|
||||
def filter_by_length(example):
|
||||
chosen_len = len(example['chosen'].split())
|
||||
rejected_len = len(example['rejected'].split())
|
||||
# Reject if chosen is much shorter (potential low-effort)
|
||||
return chosen_len >= rejected_len * 0.5
|
||||
|
||||
dataset = dataset.filter(filter_by_length)
|
||||
```
|
||||
|
||||
**Filter by diversity**:
|
||||
```python
|
||||
seen_prompts = set()
|
||||
|
||||
def filter_duplicates(example):
|
||||
prompt = example['prompt']
|
||||
if prompt in seen_prompts:
|
||||
return False
|
||||
seen_prompts.add(prompt)
|
||||
return True
|
||||
|
||||
dataset = dataset.filter(filter_duplicates)
|
||||
```
|
||||
|
||||
## Custom Dataset Creation
|
||||
|
||||
### Format 1: JSON Lines
|
||||
|
||||
**File** (`preferences.jsonl`):
|
||||
```jsonl
|
||||
{"prompt": "What is Python?", "chosen": "Python is a high-level programming language...", "rejected": "It's a snake."}
|
||||
{"prompt": "Explain AI.", "chosen": "AI refers to systems that can...", "rejected": "It's computers that think."}
|
||||
```
|
||||
|
||||
**Load**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
json:
|
||||
data_files: preferences.jsonl
|
||||
```
|
||||
|
||||
### Format 2: HuggingFace Dataset
|
||||
|
||||
**Create from dict**:
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
data = {
|
||||
"prompt": ["What is Python?", "Explain AI."],
|
||||
"chosen": ["Python is...", "AI refers to..."],
|
||||
"rejected": ["It's a snake.", "It's computers..."]
|
||||
}
|
||||
|
||||
dataset = Dataset.from_dict(data)
|
||||
dataset.push_to_hub("username/my-preferences")
|
||||
```
|
||||
|
||||
**Use in config**:
|
||||
```yaml
|
||||
dataset_mixer:
|
||||
username/my-preferences: 1.0
|
||||
```
|
||||
|
||||
### Format 3: ChatML
|
||||
|
||||
**For conversational data**:
|
||||
```json
|
||||
{
|
||||
"prompt": [
|
||||
{"role": "user", "content": "What is quantum computing?"}
|
||||
],
|
||||
"chosen": [
|
||||
{"role": "assistant", "content": "Quantum computing uses qubits..."}
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "assistant", "content": "It's like regular computing but quantum."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Apply chat template**:
|
||||
```yaml
|
||||
dataset_text_field: null # Will apply chat template
|
||||
```
|
||||
|
||||
## Synthetic Data Generation
|
||||
|
||||
### Using GPT-4
|
||||
|
||||
**Prompt template**:
|
||||
```
|
||||
Given the following question:
|
||||
{prompt}
|
||||
|
||||
Generate two responses:
|
||||
1. A high-quality, detailed response (chosen)
|
||||
2. A low-quality, brief response (rejected)
|
||||
|
||||
Format as JSON with "chosen" and "rejected" fields.
|
||||
```
|
||||
|
||||
**Example code**:
|
||||
```python
|
||||
import openai
|
||||
|
||||
def generate_pair(prompt):
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-4",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Given: {prompt}\n\nGenerate chosen/rejected pair in JSON."
|
||||
}]
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
|
||||
# Generate dataset
|
||||
prompts = load_prompts()
|
||||
dataset = [generate_pair(p) for p in prompts]
|
||||
```
|
||||
|
||||
### Using Local Model
|
||||
|
||||
**With vLLM**:
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
llm = LLM(model="meta-llama/Meta-Llama-3-70B-Instruct")
|
||||
|
||||
def generate_variations(prompt):
|
||||
# Generate multiple completions
|
||||
outputs = llm.generate(
|
||||
[prompt] * 4,
|
||||
sampling_params={
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.9,
|
||||
"max_tokens": 512
|
||||
}
|
||||
)
|
||||
|
||||
# Select best/worst
|
||||
chosen = max(outputs, key=lambda x: len(x.outputs[0].text))
|
||||
rejected = min(outputs, key=lambda x: len(x.outputs[0].text))
|
||||
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"chosen": chosen.outputs[0].text,
|
||||
"rejected": rejected.outputs[0].text
|
||||
}
|
||||
```
|
||||
|
||||
## Data Preprocessing
|
||||
|
||||
### Truncation
|
||||
|
||||
**Limit sequence length**:
|
||||
```yaml
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 512
|
||||
max_length: 1024 # Total
|
||||
```
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
def truncate_example(example):
|
||||
tokenizer.truncation_side = "left" # For prompts
|
||||
prompt_tokens = tokenizer(
|
||||
example['prompt'],
|
||||
max_length=512,
|
||||
truncation=True
|
||||
)
|
||||
|
||||
tokenizer.truncation_side = "right" # For completions
|
||||
chosen_tokens = tokenizer(
|
||||
example['chosen'],
|
||||
max_length=512,
|
||||
truncation=True
|
||||
)
|
||||
|
||||
return {
|
||||
"prompt": tokenizer.decode(prompt_tokens['input_ids']),
|
||||
"chosen": tokenizer.decode(chosen_tokens['input_ids'])
|
||||
}
|
||||
|
||||
dataset = dataset.map(truncate_example)
|
||||
```
|
||||
|
||||
### Deduplication
|
||||
|
||||
**Remove exact duplicates**:
|
||||
```python
|
||||
dataset = dataset.unique('prompt')
|
||||
```
|
||||
|
||||
**Remove near-duplicates** (MinHash):
|
||||
```python
|
||||
from datasketch import MinHash, MinHashLSH
|
||||
|
||||
def deduplicate_lsh(dataset, threshold=0.8):
|
||||
lsh = MinHashLSH(threshold=threshold, num_perm=128)
|
||||
seen = []
|
||||
|
||||
for i, example in enumerate(dataset):
|
||||
m = MinHash(num_perm=128)
|
||||
for word in example['prompt'].split():
|
||||
m.update(word.encode('utf8'))
|
||||
|
||||
if not lsh.query(m):
|
||||
lsh.insert(i, m)
|
||||
seen.append(example)
|
||||
|
||||
return Dataset.from_list(seen)
|
||||
|
||||
dataset = deduplicate_lsh(dataset)
|
||||
```
|
||||
|
||||
## Data Augmentation
|
||||
|
||||
### Paraphrasing Prompts
|
||||
|
||||
```python
|
||||
def paraphrase_prompt(example):
|
||||
# Use paraphrasing model
|
||||
paraphrased = paraphrase_model(example['prompt'])
|
||||
|
||||
return [
|
||||
example, # Original
|
||||
{
|
||||
"prompt": paraphrased,
|
||||
"chosen": example['chosen'],
|
||||
"rejected": example['rejected']
|
||||
}
|
||||
]
|
||||
|
||||
dataset = dataset.map(paraphrase_prompt, batched=False, remove_columns=[])
|
||||
```
|
||||
|
||||
### Difficulty Balancing
|
||||
|
||||
**Mix easy/medium/hard**:
|
||||
```python
|
||||
def categorize_difficulty(example):
|
||||
prompt_len = len(example['prompt'].split())
|
||||
if prompt_len < 20:
|
||||
return "easy"
|
||||
elif prompt_len < 50:
|
||||
return "medium"
|
||||
else:
|
||||
return "hard"
|
||||
|
||||
dataset = dataset.map(lambda x: {"difficulty": categorize_difficulty(x)})
|
||||
|
||||
# Sample balanced dataset
|
||||
easy = dataset.filter(lambda x: x['difficulty'] == 'easy').shuffle().select(range(1000))
|
||||
medium = dataset.filter(lambda x: x['difficulty'] == 'medium').shuffle().select(range(1000))
|
||||
hard = dataset.filter(lambda x: x['difficulty'] == 'hard').shuffle().select(range(1000))
|
||||
|
||||
balanced = concatenate_datasets([easy, medium, hard]).shuffle()
|
||||
```
|
||||
|
||||
## Dataset Statistics
|
||||
|
||||
### Compute Stats
|
||||
|
||||
```python
|
||||
def compute_stats(dataset):
|
||||
prompt_lens = [len(x['prompt'].split()) for x in dataset]
|
||||
chosen_lens = [len(x['chosen'].split()) for x in dataset]
|
||||
rejected_lens = [len(x['rejected'].split()) for x in dataset]
|
||||
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
print(f"Avg prompt length: {np.mean(prompt_lens):.1f} words")
|
||||
print(f"Avg chosen length: {np.mean(chosen_lens):.1f} words")
|
||||
print(f"Avg rejected length: {np.mean(rejected_lens):.1f} words")
|
||||
print(f"Chosen > Rejected: {sum(c > r for c, r in zip(chosen_lens, rejected_lens)) / len(dataset):.1%}")
|
||||
|
||||
compute_stats(dataset)
|
||||
```
|
||||
|
||||
**Expected output**:
|
||||
```
|
||||
Dataset size: 50000
|
||||
Avg prompt length: 45.2 words
|
||||
Avg chosen length: 180.5 words
|
||||
Avg rejected length: 120.3 words
|
||||
Chosen > Rejected: 85.2%
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Data Quality Over Quantity
|
||||
|
||||
- **Prefer**: 10K high-quality pairs
|
||||
- **Over**: 100K noisy pairs
|
||||
|
||||
### 2. Clear Preference Signals
|
||||
|
||||
- Chosen should be noticeably better
|
||||
- Avoid marginal differences
|
||||
- Remove ambiguous pairs
|
||||
|
||||
### 3. Domain Matching
|
||||
|
||||
- Match dataset domain to target use case
|
||||
- Mix datasets for broader coverage
|
||||
- Include safety-filtered data
|
||||
|
||||
### 4. Validate Before Training
|
||||
|
||||
```python
|
||||
# Sample 10 random examples
|
||||
samples = dataset.shuffle().select(range(10))
|
||||
|
||||
for ex in samples:
|
||||
print(f"Prompt: {ex['prompt']}")
|
||||
print(f"Chosen: {ex['chosen'][:100]}...")
|
||||
print(f"Rejected: {ex['rejected'][:100]}...")
|
||||
print(f"Preference clear: {'✓' if len(ex['chosen']) > len(ex['rejected']) else '?'}")
|
||||
print()
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- HuggingFace Datasets: https://huggingface.co/datasets
|
||||
- Alignment Handbook: https://github.com/huggingface/alignment-handbook
|
||||
- UltraFeedback: https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized
|
||||
452
skills/mlops/simpo/references/hyperparameters.md
Normal file
452
skills/mlops/simpo/references/hyperparameters.md
Normal file
|
|
@ -0,0 +1,452 @@
|
|||
# Hyperparameters
|
||||
|
||||
Complete guide to SimPO hyperparameter selection and tuning.
|
||||
|
||||
## Overview
|
||||
|
||||
Key hyperparameters in SimPO:
|
||||
1. **Learning Rate** - Most critical
|
||||
2. **Beta (β)** - Reward scaling
|
||||
3. **Gamma-Beta Ratio (γ/β)** - Target margin
|
||||
4. **SFT Weight** - Regularization strength
|
||||
|
||||
## Learning Rate
|
||||
|
||||
### Recommended Ranges
|
||||
|
||||
**By model size**:
|
||||
| Model Size | Learning Rate | Notes |
|
||||
|------------|---------------|-------|
|
||||
| 1B-3B | 5e-7 to 1e-6 | Higher end safe |
|
||||
| 7B-8B | 3e-7 to 5e-7 | **Standard** |
|
||||
| 13B-30B | 1e-7 to 3e-7 | Lower for stability |
|
||||
| 70B+ | 5e-8 to 1e-7 | Very conservative |
|
||||
|
||||
**By task type**:
|
||||
| Task | Learning Rate | Reason |
|
||||
|------|---------------|--------|
|
||||
| General chat | 5e-7 | Standard |
|
||||
| Code generation | 3e-7 | **Precise reasoning** |
|
||||
| Math reasoning | 3e-7 | **Careful optimization** |
|
||||
| Creative writing | 1e-6 | More aggressive OK |
|
||||
|
||||
### Why Learning Rate Matters
|
||||
|
||||
**Too high** (> 1e-6 for 7B):
|
||||
- Loss divergence
|
||||
- Catastrophic forgetting
|
||||
- Unstable training
|
||||
|
||||
**Too low** (< 1e-7 for 7B):
|
||||
- Very slow convergence
|
||||
- May not finish in time
|
||||
- Undertraining
|
||||
|
||||
**Optimal** (3e-7 to 5e-7 for 7B):
|
||||
- Stable convergence
|
||||
- Good final performance
|
||||
- Efficient training
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Mistral 7B (general)**:
|
||||
```yaml
|
||||
learning_rate: 5e-7
|
||||
num_train_epochs: 1
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: cosine
|
||||
```
|
||||
|
||||
**Llama 3 8B (reasoning)**:
|
||||
```yaml
|
||||
learning_rate: 3e-7
|
||||
num_train_epochs: 1
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: cosine
|
||||
```
|
||||
|
||||
**Gemma 2 9B (creative)**:
|
||||
```yaml
|
||||
learning_rate: 1e-6
|
||||
num_train_epochs: 1
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: linear
|
||||
```
|
||||
|
||||
## Beta (β)
|
||||
|
||||
### Recommended Values
|
||||
|
||||
**Range**: 2.0 to 10.0 (much higher than DPO's 0.01-0.1)
|
||||
|
||||
**By preference strength**:
|
||||
| Beta | Preference Strength | Use Case |
|
||||
|------|-------------------|----------|
|
||||
| 1.0-2.0 | Weak | Subtle preferences |
|
||||
| 2.0-5.0 | **Standard** | General alignment |
|
||||
| 5.0-10.0 | Strong | Clear preferences |
|
||||
|
||||
**Default**: 2.0 to 2.5
|
||||
|
||||
### Why Beta Matters
|
||||
|
||||
**Low beta** (< 2.0):
|
||||
- Weak reward signal
|
||||
- Slow preference learning
|
||||
- May underfit
|
||||
|
||||
**High beta** (> 10.0):
|
||||
- Very strong reward signal
|
||||
- Risk of overfitting
|
||||
- May ignore weak preferences
|
||||
|
||||
**Optimal** (2.0-5.0):
|
||||
- Balanced reward scaling
|
||||
- Stable training
|
||||
- Good generalization
|
||||
|
||||
### Interaction with Gamma
|
||||
|
||||
**Beta and gamma together**:
|
||||
```
|
||||
Target margin in reward space = gamma
|
||||
Target margin in logit space = gamma / beta
|
||||
```
|
||||
|
||||
**Example**:
|
||||
```yaml
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
# Effective gamma = 2.0 * 0.5 = 1.0
|
||||
```
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Weak preferences**:
|
||||
```yaml
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.3 # Small margin
|
||||
```
|
||||
|
||||
**Standard**:
|
||||
```yaml
|
||||
beta: 2.5
|
||||
gamma_beta_ratio: 0.5 # Default
|
||||
```
|
||||
|
||||
**Strong preferences**:
|
||||
```yaml
|
||||
beta: 5.0
|
||||
gamma_beta_ratio: 0.7 # Larger margin
|
||||
```
|
||||
|
||||
## Gamma-Beta Ratio (γ/β)
|
||||
|
||||
### Recommended Values
|
||||
|
||||
**Range**: 0.0 to 1.0
|
||||
|
||||
**By scenario**:
|
||||
| Ratio | Margin | Use Case |
|
||||
|-------|--------|----------|
|
||||
| 0.0-0.3 | Small | Weak preference data |
|
||||
| 0.4-0.6 | **Standard** | General use |
|
||||
| 0.7-1.0 | Large | Very clear preferences |
|
||||
|
||||
**Default**: 0.5
|
||||
|
||||
### Why Gamma Matters
|
||||
|
||||
**Low gamma** (< 0.3):
|
||||
- Small target margin
|
||||
- Less aggressive alignment
|
||||
- More conservative
|
||||
|
||||
**High gamma** (> 0.7):
|
||||
- Large target margin
|
||||
- Stronger alignment
|
||||
- More aggressive
|
||||
|
||||
**Optimal** (0.4-0.6):
|
||||
- Balanced margin
|
||||
- Stable training
|
||||
- Good alignment
|
||||
|
||||
### Mathematical Meaning
|
||||
|
||||
**In loss function**:
|
||||
```python
|
||||
logits = pi_logratios - gamma_beta_ratio
|
||||
loss = -log(sigmoid(beta * logits))
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- gamma_beta_ratio shifts the decision boundary
|
||||
- Higher ratio = requires larger log prob difference
|
||||
- Controls how "clear" preferences must be
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Noisy preferences**:
|
||||
```yaml
|
||||
gamma_beta_ratio: 0.3 # Smaller margin, more tolerant
|
||||
```
|
||||
|
||||
**Standard**:
|
||||
```yaml
|
||||
gamma_beta_ratio: 0.5 # Default
|
||||
```
|
||||
|
||||
**High-quality preferences**:
|
||||
```yaml
|
||||
gamma_beta_ratio: 0.8 # Larger margin, stricter
|
||||
```
|
||||
|
||||
## SFT Weight
|
||||
|
||||
### Recommended Values
|
||||
|
||||
**Range**: 0.0 to 1.0
|
||||
|
||||
**By model type**:
|
||||
| Model Type | SFT Weight | Reason |
|
||||
|------------|-----------|--------|
|
||||
| Base model | 0.0 | No prior capabilities |
|
||||
| **Instruct model** | 0.05-0.1 | Preserve instruction following |
|
||||
| Chat model | 0.1-0.2 | Preserve conversational skills |
|
||||
|
||||
**Default**: 0.0 (no SFT regularization)
|
||||
|
||||
### Why SFT Weight Matters
|
||||
|
||||
**Zero SFT** (0.0):
|
||||
- Pure preference optimization
|
||||
- May forget capabilities
|
||||
- Standard for base models
|
||||
|
||||
**Low SFT** (0.05-0.1):
|
||||
- Balanced approach
|
||||
- **Recommended for instruct models**
|
||||
- Slight capability preservation
|
||||
|
||||
**High SFT** (> 0.2):
|
||||
- Strong capability preservation
|
||||
- Weaker preference alignment
|
||||
- May reduce alignment gains
|
||||
|
||||
### Trade-off
|
||||
|
||||
```
|
||||
Total Loss = SimPO Loss + (sft_weight * SFT Loss)
|
||||
```
|
||||
|
||||
**Example**:
|
||||
```yaml
|
||||
sft_weight: 0.1
|
||||
# 90% preference optimization + 10% capability preservation
|
||||
```
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Base model (no SFT)**:
|
||||
```yaml
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
sft_weight: 0.0
|
||||
```
|
||||
|
||||
**Instruct model (light SFT)**:
|
||||
```yaml
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
sft_weight: 0.1
|
||||
```
|
||||
|
||||
**Chat model (moderate SFT)**:
|
||||
```yaml
|
||||
model_name_or_path: HuggingFaceH4/zephyr-7b-beta
|
||||
sft_weight: 0.2
|
||||
```
|
||||
|
||||
## Model-Size-Specific Recommendations
|
||||
|
||||
### 7B Models (Mistral, Llama 3)
|
||||
|
||||
**Standard config**:
|
||||
```yaml
|
||||
learning_rate: 5e-7
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
sft_weight: 0.0 # 0.1 if instruct model
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
```
|
||||
|
||||
### 8B-13B Models
|
||||
|
||||
**Standard config**:
|
||||
```yaml
|
||||
learning_rate: 3e-7
|
||||
beta: 2.5
|
||||
gamma_beta_ratio: 0.5
|
||||
sft_weight: 0.1 # If instruct
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
```
|
||||
|
||||
### 70B Models
|
||||
|
||||
**Standard config**:
|
||||
```yaml
|
||||
learning_rate: 1e-7
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
sft_weight: 0.05
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 16
|
||||
```
|
||||
|
||||
## Batch Size & Gradient Accumulation
|
||||
|
||||
### Effective Batch Size
|
||||
|
||||
```
|
||||
Effective Batch Size = per_device_batch_size * num_gpus * grad_accum_steps
|
||||
```
|
||||
|
||||
**Recommended effective batch sizes**:
|
||||
- 7B: 128-256
|
||||
- 13B: 64-128
|
||||
- 70B: 32-64
|
||||
|
||||
### Config Examples
|
||||
|
||||
**Single GPU (A100 40GB)**:
|
||||
```yaml
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 128 # Effective batch = 128
|
||||
```
|
||||
|
||||
**4 GPUs (A100 40GB)**:
|
||||
```yaml
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 16 # Effective batch = 2*4*16 = 128
|
||||
```
|
||||
|
||||
**8 GPUs (A100 80GB)**:
|
||||
```yaml
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 8 # Effective batch = 2*8*8 = 128
|
||||
```
|
||||
|
||||
## Loss Type
|
||||
|
||||
### Sigmoid vs Hinge
|
||||
|
||||
**Sigmoid** (default, recommended):
|
||||
```yaml
|
||||
loss_type: sigmoid
|
||||
label_smoothing: 0.0
|
||||
```
|
||||
|
||||
**Hinge** (experimental):
|
||||
```yaml
|
||||
loss_type: hinge
|
||||
# No label smoothing for hinge
|
||||
```
|
||||
|
||||
**When to use hinge**:
|
||||
- Margin-based tasks
|
||||
- SVM-style optimization
|
||||
- Experimental purposes
|
||||
|
||||
**Generally**: Stick with sigmoid
|
||||
|
||||
## Tuning Guide
|
||||
|
||||
### Step 1: Start with Defaults
|
||||
|
||||
```yaml
|
||||
learning_rate: 5e-7 # For 7B
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
sft_weight: 0.0 # 0.1 if instruct
|
||||
loss_type: sigmoid
|
||||
```
|
||||
|
||||
### Step 2: Monitor Training
|
||||
|
||||
**Check every 100 steps**:
|
||||
- Loss curve (should decrease smoothly)
|
||||
- Reward margin (should increase)
|
||||
- Chosen/rejected logps (should separate)
|
||||
|
||||
### Step 3: Adjust if Needed
|
||||
|
||||
**If loss diverges**:
|
||||
```yaml
|
||||
learning_rate: 3e-7 # Reduce from 5e-7
|
||||
beta: 1.0 # Reduce from 2.0
|
||||
```
|
||||
|
||||
**If loss plateaus early**:
|
||||
```yaml
|
||||
learning_rate: 1e-6 # Increase from 5e-7
|
||||
beta: 5.0 # Increase from 2.0
|
||||
```
|
||||
|
||||
**If model forgets**:
|
||||
```yaml
|
||||
sft_weight: 0.2 # Increase from 0.0
|
||||
```
|
||||
|
||||
## Complete Example Configs
|
||||
|
||||
### Mistral 7B Base (Standard)
|
||||
|
||||
```yaml
|
||||
model_name_or_path: mistralai/Mistral-7B-v0.1
|
||||
dataset_mixer:
|
||||
HuggingFaceH4/ultrafeedback_binarized: 1.0
|
||||
|
||||
learning_rate: 5e-7
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
loss_type: sigmoid
|
||||
sft_weight: 0.0
|
||||
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: cosine
|
||||
|
||||
bf16: true
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
### Llama 3 8B Instruct (Reasoning)
|
||||
|
||||
```yaml
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
dataset_mixer:
|
||||
argilla/distilabel-math-preference-dpo: 1.0
|
||||
|
||||
learning_rate: 3e-7
|
||||
beta: 5.0
|
||||
gamma_beta_ratio: 0.7
|
||||
loss_type: sigmoid
|
||||
sft_weight: 0.1
|
||||
|
||||
num_train_epochs: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 16
|
||||
warmup_ratio: 0.1
|
||||
lr_scheduler_type: cosine
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- SimPO paper: https://arxiv.org/abs/2405.14734
|
||||
- Alignment Handbook: https://github.com/huggingface/alignment-handbook
|
||||
350
skills/mlops/simpo/references/loss-functions.md
Normal file
350
skills/mlops/simpo/references/loss-functions.md
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
# Loss Functions
|
||||
|
||||
Complete guide to SimPO loss functions and mathematical formulations.
|
||||
|
||||
## Overview
|
||||
|
||||
SimPO supports two loss types:
|
||||
- **Sigmoid** (default) - Smooth, differentiable loss
|
||||
- **Hinge** - Margin-based, sparse loss
|
||||
|
||||
Both are reference-free (no reference model needed).
|
||||
|
||||
## SimPO Loss Formula
|
||||
|
||||
### Core Calculation
|
||||
|
||||
**Step 1: Log probability ratio**:
|
||||
```
|
||||
pi_logratios = log P_θ(y_chosen|x) - log P_θ(y_rejected|x)
|
||||
```
|
||||
|
||||
**Step 2: Apply target margin**:
|
||||
```
|
||||
logits = pi_logratios - γ/β
|
||||
```
|
||||
Where:
|
||||
- γ/β = `gamma_beta_ratio` (target margin)
|
||||
|
||||
**Step 3: Compute loss** (depends on loss type)
|
||||
|
||||
### Sigmoid Loss (Default)
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
|
||||
```
|
||||
|
||||
Where:
|
||||
- β = `beta` (reward scaling)
|
||||
- σ = sigmoid function
|
||||
- ε = `label_smoothing` (default 0.0)
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
losses = (
|
||||
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
||||
)
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- Smooth, continuous gradients
|
||||
- Probabilistic interpretation
|
||||
- Standard choice for most tasks
|
||||
- Works well with higher beta values
|
||||
|
||||
### Hinge Loss
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
L = max(0, 1 - β * logits)
|
||||
```
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
losses = torch.relu(1 - self.beta * logits)
|
||||
```
|
||||
|
||||
**Characteristics**:
|
||||
- Non-smooth (has kink at logits = 1/β)
|
||||
- Margin-based (SVM-style)
|
||||
- Can lead to sparser solutions
|
||||
- Less commonly used
|
||||
|
||||
## Comparison to DPO
|
||||
|
||||
### DPO Loss (Reference Model Required)
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
L_DPO = -E[log σ(β * log(π_θ(y_w|x)/π_ref(y_w|x)) - β * log(π_θ(y_l|x)/π_ref(y_l|x)))]
|
||||
```
|
||||
|
||||
**Key features**:
|
||||
- Requires reference model π_ref
|
||||
- Normalizes by reference log probabilities
|
||||
- More conservative (stays close to reference)
|
||||
|
||||
### SimPO Loss (Reference-Free)
|
||||
|
||||
**Formula**:
|
||||
```
|
||||
L_SimPO = -log σ(β * (log π_θ(y_w|x) - log π_θ(y_l|x) - γ/β))
|
||||
```
|
||||
|
||||
**Key features**:
|
||||
- No reference model needed
|
||||
- Direct preference optimization
|
||||
- Target margin γ/β controls preference strength
|
||||
- More efficient (fewer model forward passes)
|
||||
|
||||
**Visual comparison**:
|
||||
```
|
||||
DPO: [Policy] - [Reference] → Loss
|
||||
SimPO: [Policy] → Loss
|
||||
```
|
||||
|
||||
## Average Log Probability Reward
|
||||
|
||||
### Calculation
|
||||
|
||||
**Per-token log probabilities**:
|
||||
```python
|
||||
# Get log probs for each token
|
||||
per_token_logps = log_softmax(logits).gather(dim=-1, index=labels)
|
||||
|
||||
# Create mask to ignore padding
|
||||
loss_mask = (labels != label_pad_token_id)
|
||||
```
|
||||
|
||||
**Average log probability** (if `average_log_prob=True`):
|
||||
```python
|
||||
avg_logp = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
||||
```
|
||||
|
||||
**Sum log probability** (if `average_log_prob=False`):
|
||||
```python
|
||||
sum_logp = (per_token_logps * loss_mask).sum(-1)
|
||||
```
|
||||
|
||||
**Why average?**
|
||||
- Normalizes for sequence length
|
||||
- Prevents bias toward shorter/longer responses
|
||||
- Standard practice in SimPO
|
||||
|
||||
### Reward Metrics
|
||||
|
||||
**Chosen reward**:
|
||||
```python
|
||||
chosen_rewards = beta * policy_chosen_logps.detach()
|
||||
```
|
||||
|
||||
**Rejected reward**:
|
||||
```python
|
||||
rejected_rewards = beta * policy_rejected_logps.detach()
|
||||
```
|
||||
|
||||
**Reward margin**:
|
||||
```python
|
||||
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
|
||||
```
|
||||
|
||||
## Label Smoothing
|
||||
|
||||
### Formula with Smoothing
|
||||
|
||||
**Sigmoid loss**:
|
||||
```
|
||||
L = -log σ(β * logits) * (1 - ε) - log σ(-β * logits) * ε
|
||||
```
|
||||
|
||||
**Effect**:
|
||||
- ε = 0.0: No smoothing (default)
|
||||
- ε = 0.1: 10% smoothing (soft labels)
|
||||
- ε = 0.5: Maximum smoothing
|
||||
|
||||
**When to use**:
|
||||
- Noisy preference labels
|
||||
- Uncertain preferences
|
||||
- Prevent overconfidence
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
label_smoothing: 0.1 # 10% smoothing
|
||||
```
|
||||
|
||||
## SFT Regularization
|
||||
|
||||
### Combined Loss
|
||||
|
||||
**With SFT component**:
|
||||
```
|
||||
L_total = L_SimPO + λ * L_SFT
|
||||
```
|
||||
|
||||
Where:
|
||||
- L_SFT = cross-entropy loss on chosen responses
|
||||
- λ = `sft_weight` (0.0 to 1.0)
|
||||
|
||||
**Implementation**:
|
||||
```python
|
||||
if self.sft_weight > 0:
|
||||
sft_loss = -policy_chosen_logps
|
||||
total_loss = simpo_loss + self.sft_weight * sft_loss
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- Preserve model capabilities
|
||||
- Prevent catastrophic forgetting
|
||||
- Fine-tuning instruct models
|
||||
|
||||
**Trade-off**:
|
||||
- Higher sft_weight: Preserve capabilities, less alignment
|
||||
- Lower sft_weight: Stronger alignment, may forget capabilities
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
sft_weight: 0.1 # 10% SFT regularization
|
||||
```
|
||||
|
||||
## Loss Type Selection
|
||||
|
||||
### Sigmoid vs Hinge
|
||||
|
||||
| Aspect | Sigmoid | Hinge |
|
||||
|--------|---------|-------|
|
||||
| Smoothness | Smooth | Non-smooth |
|
||||
| Gradients | Continuous | Discontinuous at margin |
|
||||
| Sparsity | Dense solutions | Sparse solutions |
|
||||
| Interpretability | Probabilistic | Geometric margin |
|
||||
| Use case | **General purpose** | Margin-based tasks |
|
||||
| Recommendation | **Default choice** | Experimental |
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
# Sigmoid (default)
|
||||
loss_type: sigmoid
|
||||
|
||||
# Hinge (alternative)
|
||||
loss_type: hinge
|
||||
```
|
||||
|
||||
## Mathematical Properties
|
||||
|
||||
### Gradient Analysis
|
||||
|
||||
**Sigmoid loss gradient**:
|
||||
```
|
||||
∂L/∂logits = -β * σ(-β * logits) * (1 - ε) + β * σ(β * logits) * ε
|
||||
```
|
||||
|
||||
**Hinge loss gradient**:
|
||||
```
|
||||
∂L/∂logits = -β if logits < 1/β
|
||||
0 otherwise
|
||||
```
|
||||
|
||||
**Implications**:
|
||||
- Sigmoid: Always provides gradient signal
|
||||
- Hinge: No gradient when margin satisfied
|
||||
|
||||
### Convergence Behavior
|
||||
|
||||
**Sigmoid**:
|
||||
- Asymptotically approaches zero loss
|
||||
- Continues optimizing even with large margins
|
||||
- Smoother training curves
|
||||
|
||||
**Hinge**:
|
||||
- Reaches zero loss at margin
|
||||
- Stops optimizing once margin satisfied
|
||||
- May have training plateaus
|
||||
|
||||
## Complete Loss Examples
|
||||
|
||||
### Example 1: Basic SimPO (Sigmoid)
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
beta: 2.0
|
||||
gamma_beta_ratio: 0.5
|
||||
loss_type: sigmoid
|
||||
label_smoothing: 0.0
|
||||
sft_weight: 0.0
|
||||
```
|
||||
|
||||
**Loss calculation**:
|
||||
```python
|
||||
# Step 1: Compute log probs
|
||||
chosen_logps = avg_log_prob(policy(chosen)) # e.g., -1.2
|
||||
rejected_logps = avg_log_prob(policy(rejected)) # e.g., -2.5
|
||||
|
||||
# Step 2: Log ratio and margin
|
||||
pi_logratios = -1.2 - (-2.5) = 1.3
|
||||
logits = 1.3 - 0.5 = 0.8
|
||||
|
||||
# Step 3: Sigmoid loss
|
||||
loss = -log(sigmoid(2.0 * 0.8))
|
||||
= -log(sigmoid(1.6))
|
||||
= -log(0.832)
|
||||
= 0.184
|
||||
```
|
||||
|
||||
### Example 2: SimPO with SFT
|
||||
|
||||
**Config**:
|
||||
```yaml
|
||||
beta: 2.5
|
||||
gamma_beta_ratio: 0.5
|
||||
loss_type: sigmoid
|
||||
sft_weight: 0.1
|
||||
```
|
||||
|
||||
**Loss calculation**:
|
||||
```python
|
||||
# SimPO loss (as above)
|
||||
simpo_loss = 0.184
|
||||
|
||||
# SFT loss
|
||||
sft_loss = -chosen_logps = -(-1.2) = 1.2
|
||||
|
||||
# Total loss
|
||||
total_loss = simpo_loss + 0.1 * sft_loss
|
||||
= 0.184 + 0.12
|
||||
= 0.304
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
### Check Reward Margins
|
||||
|
||||
**Low margin (< 0.5)**:
|
||||
- Preferences not being learned
|
||||
- Increase beta or gamma_beta_ratio
|
||||
|
||||
**High margin (> 5.0)**:
|
||||
- May be overfitting
|
||||
- Reduce beta or learning rate
|
||||
|
||||
**Monitor**:
|
||||
```python
|
||||
reward_margin = chosen_rewards.mean() - rejected_rewards.mean()
|
||||
print(f"Reward margin: {reward_margin:.2f}")
|
||||
```
|
||||
|
||||
### Check Log Probabilities
|
||||
|
||||
**Typical values**:
|
||||
- Chosen: -1.0 to -2.0 (higher is better)
|
||||
- Rejected: -2.0 to -4.0 (lower is worse)
|
||||
|
||||
**Warning signs**:
|
||||
- Both very negative (< -10): Model not learning
|
||||
- Both very positive (> 0): Numerical instability
|
||||
|
||||
## References
|
||||
|
||||
- SimPO paper: https://arxiv.org/abs/2405.14734
|
||||
- DPO paper: https://arxiv.org/abs/2305.18290
|
||||
- Implementation: https://github.com/princeton-nlp/SimPO
|
||||
467
skills/mlops/slime/SKILL.md
Normal file
467
skills/mlops/slime/SKILL.md
Normal file
|
|
@ -0,0 +1,467 @@
|
|||
---
|
||||
name: slime-rl-training
|
||||
description: Provides guidance for LLM post-training with RL using slime, a Megatron+SGLang framework. Use when training GLM models, implementing custom data generation workflows, or needing tight Megatron-LM integration for RL scaling.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [sglang-router>=0.2.3, ray, torch>=2.0.0, transformers>=4.40.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Reinforcement Learning, Megatron-LM, SGLang, GRPO, Post-Training, GLM]
|
||||
|
||||
---
|
||||
|
||||
# slime: LLM Post-Training Framework for RL Scaling
|
||||
|
||||
slime is an LLM post-training framework from Tsinghua's THUDM team, powering GLM-4.5, GLM-4.6, and GLM-4.7. It connects Megatron-LM for training with SGLang for high-throughput rollout generation.
|
||||
|
||||
## When to Use slime
|
||||
|
||||
**Choose slime when you need:**
|
||||
- Megatron-LM native training with SGLang inference
|
||||
- Custom data generation workflows with flexible data buffers
|
||||
- Training GLM, Qwen3, DeepSeek V3, or Llama 3 models
|
||||
- Research-grade framework with production backing (Z.ai)
|
||||
|
||||
**Consider alternatives when:**
|
||||
- You need enterprise-grade stability features → use **miles**
|
||||
- You want flexible backend swapping → use **verl**
|
||||
- You need PyTorch-native abstractions → use **torchforge**
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Training**: Megatron-LM with full parallelism support (TP, PP, DP, SP)
|
||||
- **Rollout**: SGLang-based high-throughput generation with router
|
||||
- **Data Buffer**: Flexible prompt management and sample storage
|
||||
- **Models**: GLM-4.x, Qwen3, DeepSeek V3/R1, Llama 3
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Data Buffer │
|
||||
│ - Prompt initialization and management │
|
||||
│ - Custom data generation and filtering │
|
||||
│ - Rollout sample storage │
|
||||
└─────────────┬───────────────────────────┬───────────────┘
|
||||
│ │
|
||||
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
|
||||
│ Training (Megatron-LM) │ │ Rollout (SGLang + Router) │
|
||||
│ - Actor model training │ │ - Response generation │
|
||||
│ - Critic (optional) │ │ - Reward/verifier output │
|
||||
│ - Weight sync to rollout│ │ - Multi-turn support │
|
||||
└─────────────────────────┘ └─────────────────────────────┘
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Recommended: Docker
|
||||
docker pull slimerl/slime:latest
|
||||
docker run --rm --gpus all --ipc=host --shm-size=16g \
|
||||
-it slimerl/slime:latest /bin/bash
|
||||
|
||||
# Inside container
|
||||
cd /root/slime && pip install -e . --no-deps
|
||||
```
|
||||
|
||||
### From Source
|
||||
|
||||
```bash
|
||||
git clone https://github.com/THUDM/slime.git
|
||||
cd slime
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Quick Start: GRPO Training
|
||||
|
||||
```bash
|
||||
# Source model configuration
|
||||
source scripts/models/qwen3-4B.sh
|
||||
|
||||
# Launch training
|
||||
python train.py \
|
||||
--actor-num-nodes 1 \
|
||||
--actor-num-gpus-per-node 4 \
|
||||
--rollout-num-gpus 4 \
|
||||
--advantage-estimator grpo \
|
||||
--use-kl-loss --kl-loss-coef 0.001 \
|
||||
--rollout-batch-size 32 \
|
||||
--n-samples-per-prompt 8 \
|
||||
--global-batch-size 256 \
|
||||
--num-rollout 3000 \
|
||||
--prompt-data /path/to/data.jsonl \
|
||||
${MODEL_ARGS[@]} ${CKPT_ARGS[@]}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Workflow 1: Standard GRPO Training
|
||||
|
||||
Use this workflow for training reasoning models with group-relative advantages.
|
||||
|
||||
### Prerequisites Checklist
|
||||
- [ ] Docker environment or Megatron-LM + SGLang installed
|
||||
- [ ] Model checkpoint (HuggingFace or Megatron format)
|
||||
- [ ] Training data in JSONL format
|
||||
|
||||
### Step 1: Prepare Data
|
||||
|
||||
```python
|
||||
# data.jsonl format
|
||||
{"prompt": "What is 2 + 2?", "label": "4"}
|
||||
{"prompt": "Solve: 3x = 12", "label": "x = 4"}
|
||||
```
|
||||
|
||||
Or with chat format:
|
||||
```python
|
||||
{
|
||||
"prompt": [
|
||||
{"role": "system", "content": "You are a math tutor."},
|
||||
{"role": "user", "content": "What is 15 + 27?"}
|
||||
],
|
||||
"label": "42"
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Configure Model
|
||||
|
||||
Choose a pre-configured model script:
|
||||
|
||||
```bash
|
||||
# List available models
|
||||
ls scripts/models/
|
||||
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh, ...
|
||||
|
||||
# Source your model
|
||||
source scripts/models/qwen3-4B.sh
|
||||
```
|
||||
|
||||
### Step 3: Launch Training
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--actor-num-nodes 1 \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--advantage-estimator grpo \
|
||||
--use-kl-loss \
|
||||
--kl-loss-coef 0.001 \
|
||||
--prompt-data /path/to/train.jsonl \
|
||||
--input-key prompt \
|
||||
--label-key label \
|
||||
--apply-chat-template \
|
||||
--rollout-batch-size 32 \
|
||||
--n-samples-per-prompt 8 \
|
||||
--global-batch-size 256 \
|
||||
--num-rollout 3000 \
|
||||
--save-interval 100 \
|
||||
--eval-interval 50 \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Step 4: Monitor Training
|
||||
- [ ] Check TensorBoard: `tensorboard --logdir outputs/`
|
||||
- [ ] Verify reward curves are increasing
|
||||
- [ ] Monitor GPU utilization across nodes
|
||||
|
||||
---
|
||||
|
||||
## Workflow 2: Asynchronous Training
|
||||
|
||||
Use async mode for higher throughput by overlapping rollout and training.
|
||||
|
||||
### When to Use Async
|
||||
- Large models with long generation times
|
||||
- High GPU idle time in synchronous mode
|
||||
- Sufficient memory for buffering
|
||||
|
||||
### Launch Async Training
|
||||
|
||||
```bash
|
||||
python train_async.py \
|
||||
--actor-num-nodes 1 \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--advantage-estimator grpo \
|
||||
--async-buffer-size 4 \
|
||||
--prompt-data /path/to/train.jsonl \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Async-Specific Parameters
|
||||
|
||||
```bash
|
||||
--async-buffer-size 4 # Number of rollouts to buffer
|
||||
--update-weights-interval 2 # Sync weights every N rollouts
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Workflow 3: Multi-Turn Agentic Training
|
||||
|
||||
Use this workflow for training agents with tool use or multi-step reasoning.
|
||||
|
||||
### Prerequisites
|
||||
- [ ] Custom generate function for multi-turn logic
|
||||
- [ ] Tool/environment interface
|
||||
|
||||
### Step 1: Define Custom Generate Function
|
||||
|
||||
```python
|
||||
# custom_generate.py
|
||||
async def custom_generate(args, samples, evaluation=False):
|
||||
"""Multi-turn generation with tool calling."""
|
||||
for sample in samples:
|
||||
conversation = sample.prompt
|
||||
|
||||
for turn in range(args.max_turns):
|
||||
# Generate response
|
||||
response = await generate_single(conversation)
|
||||
|
||||
# Check for tool call
|
||||
tool_call = extract_tool_call(response)
|
||||
if tool_call:
|
||||
tool_result = execute_tool(tool_call)
|
||||
conversation.append({"role": "assistant", "content": response})
|
||||
conversation.append({"role": "tool", "content": tool_result})
|
||||
else:
|
||||
break
|
||||
|
||||
sample.response = response
|
||||
sample.reward = compute_reward(sample)
|
||||
|
||||
return samples
|
||||
```
|
||||
|
||||
### Step 2: Launch with Custom Function
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--custom-generate-function-path custom_generate.py \
|
||||
--max-turns 5 \
|
||||
--prompt-data /path/to/agent_data.jsonl \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
See `examples/search-r1/` for a complete multi-turn search example.
|
||||
|
||||
---
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### Three Argument Categories
|
||||
|
||||
slime uses three types of arguments:
|
||||
|
||||
**1. Megatron Arguments** (passed directly):
|
||||
```bash
|
||||
--tensor-model-parallel-size 2
|
||||
--pipeline-model-parallel-size 1
|
||||
--num-layers 32
|
||||
--hidden-size 4096
|
||||
```
|
||||
|
||||
**2. SGLang Arguments** (prefixed with `--sglang-`):
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.8
|
||||
--sglang-context-length 8192
|
||||
--sglang-log-level INFO
|
||||
```
|
||||
|
||||
**3. slime Arguments**:
|
||||
```bash
|
||||
# Resource allocation
|
||||
--actor-num-nodes 1
|
||||
--actor-num-gpus-per-node 8
|
||||
--rollout-num-gpus 8
|
||||
--colocate # Share GPUs between training/inference
|
||||
|
||||
# Data
|
||||
--prompt-data /path/to/data.jsonl
|
||||
--input-key prompt
|
||||
--label-key label
|
||||
|
||||
# Training loop
|
||||
--num-rollout 3000
|
||||
--rollout-batch-size 32
|
||||
--n-samples-per-prompt 8
|
||||
--global-batch-size 256
|
||||
|
||||
# Algorithm
|
||||
--advantage-estimator grpo # or: gspo, ppo, reinforce_plus_plus
|
||||
--use-kl-loss
|
||||
--kl-loss-coef 0.001
|
||||
```
|
||||
|
||||
### Key Constraints
|
||||
|
||||
```
|
||||
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
|
||||
```
|
||||
|
||||
Example: 32 × 8 = 256 × 1
|
||||
|
||||
---
|
||||
|
||||
## Data Buffer System
|
||||
|
||||
slime's data buffer enables flexible data management:
|
||||
|
||||
### Basic Data Source
|
||||
|
||||
```python
|
||||
class RolloutDataSource:
|
||||
def get_samples(self, num_samples):
|
||||
"""Fetch prompts from dataset."""
|
||||
return self.dataset.sample(num_samples)
|
||||
|
||||
def add_samples(self, samples):
|
||||
"""Called after generation (no-op by default)."""
|
||||
pass
|
||||
```
|
||||
|
||||
### Buffered Data Source (Off-Policy)
|
||||
|
||||
```python
|
||||
class RolloutDataSourceWithBuffer(RolloutDataSource):
|
||||
def __init__(self):
|
||||
self.buffer = []
|
||||
|
||||
def add_samples(self, samples):
|
||||
"""Store generated samples for reuse."""
|
||||
self.buffer.extend(samples)
|
||||
|
||||
def buffer_filter(self, args, buffer, num_samples):
|
||||
"""Custom selection logic (prioritized, stratified, etc.)."""
|
||||
return select_best(buffer, num_samples)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Issue: SGLang Engine Crash
|
||||
|
||||
**Symptoms**: Inference engine dies mid-training
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Enable fault tolerance
|
||||
--use-fault-tolerance
|
||||
|
||||
# Increase memory allocation
|
||||
--sglang-mem-fraction-static 0.85
|
||||
|
||||
# Reduce batch size
|
||||
--rollout-batch-size 16
|
||||
```
|
||||
|
||||
### Issue: Weight Sync Timeout
|
||||
|
||||
**Symptoms**: Training hangs after rollout
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Increase sync interval
|
||||
--update-weights-interval 5
|
||||
|
||||
# Use colocated mode (no network transfer)
|
||||
--colocate
|
||||
```
|
||||
|
||||
### Issue: OOM During Training
|
||||
|
||||
**Symptoms**: CUDA OOM in backward pass
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Enable gradient checkpointing
|
||||
--recompute-activations
|
||||
|
||||
# Reduce micro-batch size
|
||||
--micro-batch-size 1
|
||||
|
||||
# Enable sequence parallelism
|
||||
--sequence-parallel
|
||||
```
|
||||
|
||||
### Issue: Slow Data Loading
|
||||
|
||||
**Symptoms**: GPU idle during data fetch
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Increase data workers
|
||||
--num-data-workers 4
|
||||
|
||||
# Use streaming dataset
|
||||
--streaming-data
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model Family | Configurations |
|
||||
|--------------|----------------|
|
||||
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
|
||||
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
|
||||
| DeepSeek | V3, V3.1, R1 |
|
||||
| Llama | Llama 3 (8B, 70B) |
|
||||
| Others | Kimi K2, Moonlight-16B |
|
||||
|
||||
Each model has pre-configured scripts in `scripts/models/`.
|
||||
|
||||
---
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Co-location Mode
|
||||
|
||||
Share GPUs between training and inference to reduce memory:
|
||||
|
||||
```bash
|
||||
python train.py \
|
||||
--colocate \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--sglang-mem-fraction-static 0.4 \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Custom Reward Model
|
||||
|
||||
```python
|
||||
# custom_rm.py
|
||||
class CustomRewardModel:
|
||||
def __init__(self, model_path):
|
||||
self.model = load_model(model_path)
|
||||
|
||||
def compute_reward(self, prompts, responses):
|
||||
inputs = self.tokenize(prompts, responses)
|
||||
scores = self.model(inputs)
|
||||
return scores.tolist()
|
||||
```
|
||||
|
||||
```bash
|
||||
--custom-rm-path custom_rm.py
|
||||
```
|
||||
|
||||
### Evaluation Multi-Task
|
||||
|
||||
```bash
|
||||
--eval-prompt-data aime /path/to/aime.jsonl \
|
||||
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
|
||||
--n-samples-per-eval-prompt 16
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://thudm.github.io/slime/
|
||||
- **GitHub**: https://github.com/THUDM/slime
|
||||
- **Blog**: https://lmsys.org/blog/2025-07-09-slime/
|
||||
- **Examples**: See `examples/` directory for 14+ worked examples
|
||||
|
||||
392
skills/mlops/slime/references/api-reference.md
Normal file
392
skills/mlops/slime/references/api-reference.md
Normal file
|
|
@ -0,0 +1,392 @@
|
|||
# slime API Reference
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
slime operates with a three-module architecture orchestrated by Ray:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Data Buffer │
|
||||
│ - Prompt initialization and management │
|
||||
│ - Custom data generation and filtering │
|
||||
│ - Rollout sample storage │
|
||||
└─────────────┬───────────────────────────┬───────────────┘
|
||||
│ │
|
||||
┌─────────────▼───────────┐ ┌─────────────▼───────────────┐
|
||||
│ Training (Megatron-LM) │ │ Rollout (SGLang + Router) │
|
||||
│ - Actor model training │ │ - Response generation │
|
||||
│ - Critic (optional) │ │ - Reward/verifier output │
|
||||
│ - Weight sync to rollout│ │ - Multi-turn support │
|
||||
└─────────────────────────┘ └─────────────────────────────┘
|
||||
```
|
||||
|
||||
## Core Data Structures
|
||||
|
||||
### Sample Object
|
||||
|
||||
The `Sample` object is the core data structure defined in `slime/utils/types.py`:
|
||||
|
||||
```python
|
||||
from slime.utils.types import Sample
|
||||
|
||||
@dataclass
|
||||
class Sample:
|
||||
# Core fields
|
||||
group_index: Optional[int] # Group index for batching
|
||||
index: Optional[int] # Sample index
|
||||
prompt: str | list[dict] = "" # Input prompt or chat history
|
||||
tokens: list[int] = field(default_factory=list) # Token IDs
|
||||
response: str = "" # Generated response
|
||||
response_length: int = 0 # Response length in tokens
|
||||
label: Optional[str] = None # Ground truth label
|
||||
reward: Optional[float | dict] = None # RL reward signal
|
||||
loss_mask: Optional[list[int]] = None # 1=compute loss, 0=mask
|
||||
status: Status = Status.PENDING # Sample status
|
||||
metadata: dict = field(default_factory=dict) # Custom data
|
||||
|
||||
# Multimodal support
|
||||
multimodal_inputs: Optional[Any] = None # Raw multimodal data (images, videos)
|
||||
multimodal_train_inputs: Optional[Any] = None # Processed multimodal data (pixel_values)
|
||||
|
||||
# Rollout tracking
|
||||
weight_versions: list[str] = field(default_factory=list)
|
||||
rollout_log_probs: Optional[list[float]] = None # Log probs from SGLang
|
||||
rollout_routed_experts: Optional[list[list[int]]] = None # Expert routing (MoE)
|
||||
|
||||
# Control fields
|
||||
remove_sample: bool = False
|
||||
generate_function_path: Optional[str] = None
|
||||
train_metadata: Optional[dict] = None
|
||||
non_generation_time: float = 0.0
|
||||
|
||||
# Speculative decoding info (nested dataclass)
|
||||
@dataclass
|
||||
class SpecInfo:
|
||||
spec_accept_token_num: int = 0
|
||||
spec_draft_token_num: int = 0
|
||||
spec_verify_ct: int = 0
|
||||
completion_token_num: int = 0
|
||||
```
|
||||
|
||||
### Status Enum
|
||||
|
||||
```python
|
||||
class Status(Enum):
|
||||
PENDING = "pending" # Not yet processed
|
||||
COMPLETED = "completed" # Successfully generated
|
||||
TRUNCATED = "truncated" # Hit max length
|
||||
ABORTED = "aborted" # Failed generation
|
||||
FAILED = "failed" # Generation failed
|
||||
```
|
||||
|
||||
## Configuration System
|
||||
|
||||
slime uses three categories of command-line arguments:
|
||||
|
||||
### 1. Megatron Arguments
|
||||
|
||||
All Megatron-LM arguments are supported directly:
|
||||
|
||||
```bash
|
||||
--tensor-model-parallel-size 2
|
||||
--pipeline-model-parallel-size 1
|
||||
--num-layers 32
|
||||
--hidden-size 4096
|
||||
--num-attention-heads 32
|
||||
--seq-length 4096
|
||||
--micro-batch-size 1
|
||||
--global-batch-size 256
|
||||
```
|
||||
|
||||
### 2. SGLang Arguments
|
||||
|
||||
SGLang arguments are prefixed with `--sglang-`:
|
||||
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.8 # GPU memory for KV cache
|
||||
--sglang-context-length 8192 # Maximum context length
|
||||
--sglang-log-level INFO # Logging verbosity
|
||||
--sglang-tp-size 2 # Tensor parallelism
|
||||
--sglang-disable-cuda-graph # Disable CUDA graphs
|
||||
```
|
||||
|
||||
### 3. slime-Specific Arguments
|
||||
|
||||
Defined in `slime/utils/arguments.py`:
|
||||
|
||||
```bash
|
||||
# Resource Allocation
|
||||
--actor-num-nodes 1 # Training nodes
|
||||
--actor-num-gpus-per-node 8 # GPUs per training node
|
||||
--rollout-num-gpus 8 # Total rollout GPUs
|
||||
--rollout-num-gpus-per-engine 2 # GPUs per SGLang engine
|
||||
--colocate # Share GPUs for train/inference
|
||||
|
||||
# Data Configuration
|
||||
--prompt-data /path/to/data.jsonl # Training data path
|
||||
--input-key prompt # Key for prompts in JSON
|
||||
--label-key label # Key for labels in JSON
|
||||
--apply-chat-template # Apply chat formatting
|
||||
|
||||
# Training Loop
|
||||
--num-rollout 3000 # Total rollout iterations
|
||||
--rollout-batch-size 32 # Prompts per rollout
|
||||
--n-samples-per-prompt 8 # Responses per prompt
|
||||
--global-batch-size 256 # Training batch size
|
||||
--num-steps-per-rollout 1 # Training steps per rollout
|
||||
|
||||
# RL Algorithm
|
||||
--advantage-estimator grpo # grpo, gspo, ppo, reinforce_plus_plus
|
||||
--use-kl-loss # Enable KL loss
|
||||
--kl-loss-coef 0.001 # KL coefficient
|
||||
--calculate-per-token-loss # Token-level loss
|
||||
|
||||
# Off-Policy Options
|
||||
--use-tis # Truncated Importance Sampling
|
||||
--tis-threshold 0.9 # TIS threshold
|
||||
--true-on-policy-mode # Force on-policy training
|
||||
```
|
||||
|
||||
## Data Buffer System
|
||||
|
||||
### RolloutDataSource (Base Class)
|
||||
|
||||
```python
|
||||
from slime.data import RolloutDataSource
|
||||
|
||||
class RolloutDataSource:
|
||||
def __init__(self, dataset, args):
|
||||
self.dataset = dataset
|
||||
self.args = args
|
||||
|
||||
def get_samples(self, num_samples: int) -> list[Sample]:
|
||||
"""Fetch prompts from dataset."""
|
||||
return [Sample(prompt=p) for p in self.dataset.sample(num_samples)]
|
||||
|
||||
def add_samples(self, samples: list[Sample]) -> None:
|
||||
"""Called after generation (no-op by default)."""
|
||||
pass
|
||||
```
|
||||
|
||||
### Buffered Data Source (Off-Policy)
|
||||
|
||||
```python
|
||||
from slime.data import RolloutDataSourceWithBuffer
|
||||
|
||||
class RolloutDataSourceWithBuffer(RolloutDataSource):
|
||||
def __init__(self, dataset, args):
|
||||
super().__init__(dataset, args)
|
||||
self.buffer = []
|
||||
|
||||
def add_samples(self, samples: list[Sample]) -> None:
|
||||
"""Store generated samples for reuse."""
|
||||
self.buffer.extend(samples)
|
||||
|
||||
def buffer_filter(self, args, buffer, num_samples) -> list[Sample]:
|
||||
"""Custom selection logic."""
|
||||
# Example: prioritized sampling based on reward
|
||||
sorted_buffer = sorted(buffer, key=lambda s: s.reward, reverse=True)
|
||||
return sorted_buffer[:num_samples]
|
||||
```
|
||||
|
||||
## Custom Functions
|
||||
|
||||
### Custom Generate Function
|
||||
|
||||
For multi-turn or tool-calling scenarios:
|
||||
|
||||
```python
|
||||
# custom_generate.py
|
||||
from slime.data import Sample
|
||||
|
||||
async def custom_generate(args, samples: list[Sample], evaluation: bool = False) -> list[Sample]:
|
||||
"""
|
||||
Custom generation function for multi-turn interactions.
|
||||
|
||||
Args:
|
||||
args: Training arguments
|
||||
samples: List of Sample objects with prompts
|
||||
evaluation: Whether this is an evaluation run
|
||||
|
||||
Returns:
|
||||
List of Sample objects with responses and rewards
|
||||
"""
|
||||
for sample in samples:
|
||||
conversation = sample.prompt if isinstance(sample.prompt, list) else [
|
||||
{"role": "user", "content": sample.prompt}
|
||||
]
|
||||
|
||||
for turn in range(args.max_turns):
|
||||
# Generate response
|
||||
response = await generate_single(conversation)
|
||||
|
||||
# Check for tool call
|
||||
tool_call = extract_tool_call(response)
|
||||
if tool_call:
|
||||
# Execute tool
|
||||
tool_result = await execute_tool(tool_call)
|
||||
conversation.append({"role": "assistant", "content": response})
|
||||
conversation.append({"role": "tool", "content": tool_result})
|
||||
else:
|
||||
# Final response
|
||||
sample.response = response
|
||||
break
|
||||
|
||||
# Compute reward
|
||||
sample.reward = compute_reward(sample)
|
||||
|
||||
# Set loss mask (1 for model tokens, 0 for tool responses)
|
||||
sample.loss_mask = build_loss_mask(sample)
|
||||
|
||||
return samples
|
||||
```
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
python train.py \
|
||||
--custom-generate-function-path custom_generate.py \
|
||||
--max-turns 5
|
||||
```
|
||||
|
||||
### Custom Reward Function
|
||||
|
||||
```python
|
||||
# custom_rm.py
|
||||
from slime.data import Sample
|
||||
|
||||
async def reward_func(args, sample: Sample, **kwargs) -> float:
|
||||
"""
|
||||
Compute reward for a single sample.
|
||||
|
||||
Args:
|
||||
args: Training arguments
|
||||
sample: Sample object with response
|
||||
|
||||
Returns:
|
||||
Reward score (float)
|
||||
"""
|
||||
response = sample.response
|
||||
ground_truth = sample.label or sample.metadata.get("answer", "")
|
||||
|
||||
# Example: exact match reward
|
||||
if response.strip() == ground_truth.strip():
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
# For batched processing (more efficient)
|
||||
async def batched_custom_rm(args, samples: list[Sample]) -> list[float]:
|
||||
"""Batch reward computation."""
|
||||
rewards = []
|
||||
for sample in samples:
|
||||
reward = await reward_func(args, sample)
|
||||
rewards.append(reward)
|
||||
return rewards
|
||||
```
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
python train.py \
|
||||
--custom-rm-path custom_rm.py \
|
||||
--group-rm # Enable batched processing
|
||||
```
|
||||
|
||||
## Model Configuration
|
||||
|
||||
### Pre-configured Model Scripts
|
||||
|
||||
Located in `scripts/models/`:
|
||||
|
||||
```bash
|
||||
# List available models
|
||||
ls scripts/models/
|
||||
# glm4-9B.sh, qwen3-4B.sh, qwen3-30B-A3B.sh, deepseek-v3.sh, llama3-8B.sh
|
||||
|
||||
# Source model configuration
|
||||
source scripts/models/qwen3-4B.sh
|
||||
# This sets MODEL_ARGS and CKPT_ARGS arrays
|
||||
```
|
||||
|
||||
### Example Model Script
|
||||
|
||||
```bash
|
||||
# scripts/models/qwen3-4B.sh
|
||||
export MODEL_ARGS=(
|
||||
--num-layers 36
|
||||
--hidden-size 2560
|
||||
--num-attention-heads 20
|
||||
--num-query-groups 4
|
||||
--ffn-hidden-size 6912
|
||||
--max-position-embeddings 32768
|
||||
--rotary-percent 1.0
|
||||
--rotary-base 1000000
|
||||
--swiglu
|
||||
--untie-embeddings-and-output-weights
|
||||
--no-position-embedding
|
||||
--normalization RMSNorm
|
||||
--tokenizer-type HuggingFaceTokenizer
|
||||
--bf16
|
||||
)
|
||||
|
||||
export CKPT_ARGS=(
|
||||
--hf-checkpoint /path/to/qwen3-4b-hf
|
||||
--initial-megatron-checkpoint /path/to/megatron/ckpt
|
||||
)
|
||||
```
|
||||
|
||||
## Async Training
|
||||
|
||||
### Enabling Async Mode
|
||||
|
||||
```bash
|
||||
python train_async.py \
|
||||
--actor-num-gpus-per-node 8 \
|
||||
--rollout-num-gpus 8 \
|
||||
--async-buffer-size 4 \
|
||||
--update-weights-interval 2 \
|
||||
${MODEL_ARGS[@]}
|
||||
```
|
||||
|
||||
### Async-Specific Parameters
|
||||
|
||||
```bash
|
||||
--async-buffer-size 4 # Number of rollouts to buffer
|
||||
--update-weights-interval 2 # Sync weights every N rollouts
|
||||
```
|
||||
|
||||
**Note**: Colocated mode (`--colocate`) is NOT supported with async training.
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Multi-Task Evaluation
|
||||
|
||||
```bash
|
||||
--eval-prompt-data aime /path/to/aime.jsonl \
|
||||
--eval-prompt-data gsm8k /path/to/gsm8k.jsonl \
|
||||
--n-samples-per-eval-prompt 16 \
|
||||
--eval-interval 50
|
||||
```
|
||||
|
||||
### Evaluation Configuration
|
||||
|
||||
```bash
|
||||
--eval-interval 50 # Evaluate every N rollouts
|
||||
--n-samples-per-eval-prompt 16 # Samples for evaluation
|
||||
--eval-temperature 0.0 # Greedy decoding for eval
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model Family | Configurations |
|
||||
|--------------|----------------|
|
||||
| GLM | GLM-4.5, GLM-4.6, GLM-4.7, GLM-Z1-9B |
|
||||
| Qwen | Qwen3 (4B, 8B, 30B-A3B), Qwen3-MoE, Qwen2.5 |
|
||||
| DeepSeek | V3, V3.1, R1 |
|
||||
| Llama | Llama 3 (8B, 70B) |
|
||||
| Others | Kimi K2, Moonlight-16B |
|
||||
|
||||
## Resources
|
||||
|
||||
- Documentation: https://thudm.github.io/slime/
|
||||
- GitHub: https://github.com/THUDM/slime
|
||||
- Blog: https://lmsys.org/blog/2025-07-09-slime/
|
||||
- Examples: `examples/` directory (14+ worked examples)
|
||||
386
skills/mlops/slime/references/troubleshooting.md
Normal file
386
skills/mlops/slime/references/troubleshooting.md
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
# slime Troubleshooting Guide
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### SGLang Issues
|
||||
|
||||
#### Issue: SGLang Engine Crash
|
||||
|
||||
**Symptoms**: Inference engine dies mid-training, connection errors
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable fault tolerance**:
|
||||
```bash
|
||||
--use-fault-tolerance
|
||||
```
|
||||
|
||||
2. **Increase memory allocation**:
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.85 # Increase from 0.8
|
||||
```
|
||||
|
||||
3. **Reduce batch size**:
|
||||
```bash
|
||||
--rollout-batch-size 16 # Reduce from 32
|
||||
```
|
||||
|
||||
4. **Disable CUDA graphs** (for debugging):
|
||||
```bash
|
||||
--sglang-disable-cuda-graph
|
||||
```
|
||||
|
||||
#### Issue: SGLang Router Load Imbalance
|
||||
|
||||
**Symptoms**: Some SGLang engines overloaded while others idle
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Adjust routing strategy**:
|
||||
```bash
|
||||
--sglang-router-strategy round_robin
|
||||
```
|
||||
|
||||
2. **Increase number of engines**:
|
||||
```bash
|
||||
--rollout-num-gpus-per-engine 1 # More engines, less GPUs each
|
||||
```
|
||||
|
||||
### Weight Synchronization Issues
|
||||
|
||||
#### Issue: Weight Sync Timeout
|
||||
|
||||
**Symptoms**: Training hangs after rollout, timeout errors
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase sync interval** (async mode):
|
||||
```bash
|
||||
--update-weights-interval 5 # Increase from 2
|
||||
```
|
||||
|
||||
2. **Use colocated mode** (eliminates network transfer):
|
||||
```bash
|
||||
--colocate
|
||||
```
|
||||
|
||||
3. **Check network bandwidth**:
|
||||
```bash
|
||||
# Verify InfiniBand is enabled
|
||||
ibstat
|
||||
```
|
||||
|
||||
#### Issue: Weight Sync Failures in Multi-Node
|
||||
|
||||
**Symptoms**: Nodes fail to receive updated weights
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Set NCCL environment**:
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_SOCKET_IFNAME=eth0
|
||||
export NCCL_IB_DISABLE=0
|
||||
```
|
||||
|
||||
2. **Increase timeout**:
|
||||
```bash
|
||||
export NCCL_TIMEOUT=1800
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
|
||||
#### Issue: OOM During Training
|
||||
|
||||
**Symptoms**: CUDA OOM in backward pass
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable gradient checkpointing**:
|
||||
```bash
|
||||
--recompute-activations
|
||||
```
|
||||
|
||||
2. **Reduce micro-batch size**:
|
||||
```bash
|
||||
--micro-batch-size 1
|
||||
```
|
||||
|
||||
3. **Enable sequence parallelism**:
|
||||
```bash
|
||||
--sequence-parallel
|
||||
```
|
||||
|
||||
4. **Reduce global batch size**:
|
||||
```bash
|
||||
--global-batch-size 128 # Reduce from 256
|
||||
```
|
||||
|
||||
#### Issue: OOM in Colocated Mode
|
||||
|
||||
**Symptoms**: OOM when both training and inference run on same GPUs
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce SGLang memory**:
|
||||
```bash
|
||||
--sglang-mem-fraction-static 0.4 # Reduce from 0.8
|
||||
```
|
||||
|
||||
2. **Enable offloading**:
|
||||
```bash
|
||||
--offload-optimizer-states
|
||||
```
|
||||
|
||||
3. **Use smaller sequence length**:
|
||||
```bash
|
||||
--seq-length 2048 # Reduce from 4096
|
||||
```
|
||||
|
||||
### Data Loading Issues
|
||||
|
||||
#### Issue: Slow Data Loading
|
||||
|
||||
**Symptoms**: GPU idle during data fetch, low GPU utilization
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase data workers**:
|
||||
```bash
|
||||
--num-data-workers 4
|
||||
```
|
||||
|
||||
2. **Use streaming dataset**:
|
||||
```bash
|
||||
--streaming-data
|
||||
```
|
||||
|
||||
3. **Pre-tokenize data**:
|
||||
```python
|
||||
# Pre-process data offline
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("model_path")
|
||||
# Save tokenized data
|
||||
```
|
||||
|
||||
#### Issue: Data Format Errors
|
||||
|
||||
**Symptoms**: KeyError, missing fields, parsing failures
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify data format**:
|
||||
```python
|
||||
import json
|
||||
with open("data.jsonl") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
assert "prompt" in data, "Missing prompt field"
|
||||
assert "label" in data, "Missing label field"
|
||||
```
|
||||
|
||||
2. **Check key names**:
|
||||
```bash
|
||||
--input-key prompt # Must match your data
|
||||
--label-key label # Must match your data
|
||||
```
|
||||
|
||||
### Training Stability Issues
|
||||
|
||||
#### Issue: Loss Explosion / NaN
|
||||
|
||||
**Symptoms**: Loss becomes NaN or explodes
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce learning rate**:
|
||||
```bash
|
||||
--lr 1e-6 # Reduce from 5e-6
|
||||
```
|
||||
|
||||
2. **Enable gradient clipping**:
|
||||
```bash
|
||||
--clip-grad 1.0
|
||||
```
|
||||
|
||||
3. **Check for data issues**:
|
||||
```python
|
||||
# Verify no empty prompts or responses
|
||||
for sample in dataset:
|
||||
assert len(sample["prompt"]) > 0
|
||||
```
|
||||
|
||||
4. **Use BF16 instead of FP16**:
|
||||
```bash
|
||||
--bf16 # More numerically stable
|
||||
```
|
||||
|
||||
#### Issue: Reward Collapse
|
||||
|
||||
**Symptoms**: Reward drops to zero, model outputs garbage
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase KL penalty**:
|
||||
```bash
|
||||
--kl-loss-coef 0.01 # Increase from 0.001
|
||||
```
|
||||
|
||||
2. **Reduce number of samples**:
|
||||
```bash
|
||||
--n-samples-per-prompt 4 # Reduce from 8
|
||||
```
|
||||
|
||||
3. **Verify reward function**:
|
||||
```python
|
||||
# Test reward function independently
|
||||
from custom_rm import reward_func
|
||||
sample = Sample(prompt="test", response="test response")
|
||||
reward = reward_func(args, sample)
|
||||
print(f"Reward: {reward}") # Should be reasonable
|
||||
```
|
||||
|
||||
### Async Training Issues
|
||||
|
||||
#### Issue: Async Training Not Supported with Colocate
|
||||
|
||||
**Symptoms**: Error when using `--colocate` with `train_async.py`
|
||||
|
||||
**Solution**: Colocated mode is NOT supported for async training. Use separate GPUs:
|
||||
```bash
|
||||
# Remove --colocate flag
|
||||
python train_async.py \
|
||||
--actor-num-gpus-per-node 4 \
|
||||
--rollout-num-gpus 4 \
|
||||
# No --colocate
|
||||
```
|
||||
|
||||
#### Issue: Stale Weights in Async Mode
|
||||
|
||||
**Symptoms**: Policy divergence, inconsistent behavior
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce async buffer size**:
|
||||
```bash
|
||||
--async-buffer-size 2 # Reduce from 4
|
||||
```
|
||||
|
||||
2. **Increase weight update frequency**:
|
||||
```bash
|
||||
--update-weights-interval 1 # Sync every rollout
|
||||
```
|
||||
|
||||
### Multi-Turn Training Issues
|
||||
|
||||
#### Issue: Tool Responses Included in Loss
|
||||
|
||||
**Symptoms**: Model learns to output tool responses verbatim
|
||||
|
||||
**Solution**: Properly set loss mask in custom generate function:
|
||||
```python
|
||||
def build_loss_mask(sample):
|
||||
"""Create loss mask that excludes tool responses."""
|
||||
mask = []
|
||||
for i, token in enumerate(sample.tokens):
|
||||
if is_tool_response(token, sample.metadata):
|
||||
mask.append(0) # Don't compute loss
|
||||
else:
|
||||
mask.append(1) # Compute loss
|
||||
return mask
|
||||
```
|
||||
|
||||
#### Issue: Multi-Turn Context Too Long
|
||||
|
||||
**Symptoms**: OOM or truncation in multi-turn conversations
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Limit conversation history**:
|
||||
```python
|
||||
# In custom generate function
|
||||
conversation = sample.prompt[-10:] # Keep last 10 turns
|
||||
```
|
||||
|
||||
2. **Increase context length**:
|
||||
```bash
|
||||
--sglang-context-length 16384
|
||||
```
|
||||
|
||||
### Checkpoint Issues
|
||||
|
||||
#### Issue: Checkpoint Loading Fails
|
||||
|
||||
**Symptoms**: Cannot load saved checkpoint
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Verify checkpoint path**:
|
||||
```bash
|
||||
ls -la /path/to/checkpoint/
|
||||
```
|
||||
|
||||
2. **Check parallelism matches**:
|
||||
```bash
|
||||
# Checkpoint was saved with TP=2, must load with TP=2
|
||||
--tensor-model-parallel-size 2
|
||||
```
|
||||
|
||||
3. **Convert HuggingFace to Megatron** (if needed):
|
||||
```bash
|
||||
python tools/convert_hf_to_megatron.py \
|
||||
--hf_model_path /path/to/hf/model \
|
||||
--save_path /path/to/megatron/checkpoint
|
||||
```
|
||||
|
||||
### Debugging Tips
|
||||
|
||||
#### Enable Verbose Logging
|
||||
|
||||
```bash
|
||||
--log-level DEBUG
|
||||
export SLIME_DEBUG=1
|
||||
```
|
||||
|
||||
#### Check GPU Utilization
|
||||
|
||||
```bash
|
||||
watch -n 1 nvidia-smi
|
||||
```
|
||||
|
||||
#### Monitor Training
|
||||
|
||||
```bash
|
||||
tensorboard --logdir outputs/
|
||||
```
|
||||
|
||||
#### Test Custom Functions Independently
|
||||
|
||||
```python
|
||||
# Test reward function
|
||||
import asyncio
|
||||
from custom_rm import reward_func
|
||||
|
||||
async def test():
|
||||
sample = Sample(prompt="test", response="test", label="expected")
|
||||
reward = await reward_func(args, sample)
|
||||
print(f"Reward: {reward}")
|
||||
|
||||
asyncio.run(test())
|
||||
```
|
||||
|
||||
## Constraint Reference
|
||||
|
||||
Key constraint to remember:
|
||||
|
||||
```
|
||||
rollout_batch_size × n_samples_per_prompt = global_batch_size × num_steps_per_rollout
|
||||
```
|
||||
|
||||
Example: `32 × 8 = 256 × 1`
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub Issues: https://github.com/THUDM/slime/issues
|
||||
- Documentation: https://thudm.github.io/slime/
|
||||
- Examples: `examples/` directory
|
||||
522
skills/mlops/stable-diffusion/SKILL.md
Normal file
522
skills/mlops/stable-diffusion/SKILL.md
Normal file
|
|
@ -0,0 +1,522 @@
|
|||
---
|
||||
name: stable-diffusion-image-generation
|
||||
description: State-of-the-art text-to-image generation with Stable Diffusion models via HuggingFace Diffusers. Use when generating images from text prompts, performing image-to-image translation, inpainting, or building custom diffusion pipelines.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [diffusers>=0.30.0, transformers>=4.41.0, accelerate>=0.31.0, torch>=2.0.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Image Generation, Stable Diffusion, Diffusers, Text-to-Image, Multimodal, Computer Vision]
|
||||
|
||||
---
|
||||
|
||||
# Stable Diffusion Image Generation
|
||||
|
||||
Comprehensive guide to generating images with Stable Diffusion using the HuggingFace Diffusers library.
|
||||
|
||||
## When to use Stable Diffusion
|
||||
|
||||
**Use Stable Diffusion when:**
|
||||
- Generating images from text descriptions
|
||||
- Performing image-to-image translation (style transfer, enhancement)
|
||||
- Inpainting (filling in masked regions)
|
||||
- Outpainting (extending images beyond boundaries)
|
||||
- Creating variations of existing images
|
||||
- Building custom image generation workflows
|
||||
|
||||
**Key features:**
|
||||
- **Text-to-Image**: Generate images from natural language prompts
|
||||
- **Image-to-Image**: Transform existing images with text guidance
|
||||
- **Inpainting**: Fill masked regions with context-aware content
|
||||
- **ControlNet**: Add spatial conditioning (edges, poses, depth)
|
||||
- **LoRA Support**: Efficient fine-tuning and style adaptation
|
||||
- **Multiple Models**: SD 1.5, SDXL, SD 3.0, Flux support
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **DALL-E 3**: For API-based generation without GPU
|
||||
- **Midjourney**: For artistic, stylized outputs
|
||||
- **Imagen**: For Google Cloud integration
|
||||
- **Leonardo.ai**: For web-based creative workflows
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install diffusers transformers accelerate torch
|
||||
pip install xformers # Optional: memory-efficient attention
|
||||
```
|
||||
|
||||
### Basic text-to-image
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
# Load pipeline (auto-detects model type)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Generate image
|
||||
image = pipe(
|
||||
"A serene mountain landscape at sunset, highly detailed",
|
||||
num_inference_steps=50,
|
||||
guidance_scale=7.5
|
||||
).images[0]
|
||||
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
### Using SDXL (higher quality)
|
||||
|
||||
```python
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Enable memory optimization
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
image = pipe(
|
||||
prompt="A futuristic city with flying cars, cinematic lighting",
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=30
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Architecture overview
|
||||
|
||||
### Three-pillar design
|
||||
|
||||
Diffusers is built around three core components:
|
||||
|
||||
```
|
||||
Pipeline (orchestration)
|
||||
├── Model (neural networks)
|
||||
│ ├── UNet / Transformer (noise prediction)
|
||||
│ ├── VAE (latent encoding/decoding)
|
||||
│ └── Text Encoder (CLIP/T5)
|
||||
└── Scheduler (denoising algorithm)
|
||||
```
|
||||
|
||||
### Pipeline inference flow
|
||||
|
||||
```
|
||||
Text Prompt → Text Encoder → Text Embeddings
|
||||
↓
|
||||
Random Noise → [Denoising Loop] ← Scheduler
|
||||
↓
|
||||
Predicted Noise
|
||||
↓
|
||||
VAE Decoder → Final Image
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Pipelines
|
||||
|
||||
Pipelines orchestrate complete workflows:
|
||||
|
||||
| Pipeline | Purpose |
|
||||
|----------|---------|
|
||||
| `StableDiffusionPipeline` | Text-to-image (SD 1.x/2.x) |
|
||||
| `StableDiffusionXLPipeline` | Text-to-image (SDXL) |
|
||||
| `StableDiffusion3Pipeline` | Text-to-image (SD 3.0) |
|
||||
| `FluxPipeline` | Text-to-image (Flux models) |
|
||||
| `StableDiffusionImg2ImgPipeline` | Image-to-image |
|
||||
| `StableDiffusionInpaintPipeline` | Inpainting |
|
||||
|
||||
### Schedulers
|
||||
|
||||
Schedulers control the denoising process:
|
||||
|
||||
| Scheduler | Steps | Quality | Use Case |
|
||||
|-----------|-------|---------|----------|
|
||||
| `EulerDiscreteScheduler` | 20-50 | Good | Default choice |
|
||||
| `EulerAncestralDiscreteScheduler` | 20-50 | Good | More variation |
|
||||
| `DPMSolverMultistepScheduler` | 15-25 | Excellent | Fast, high quality |
|
||||
| `DDIMScheduler` | 50-100 | Good | Deterministic |
|
||||
| `LCMScheduler` | 4-8 | Good | Very fast |
|
||||
| `UniPCMultistepScheduler` | 15-25 | Excellent | Fast convergence |
|
||||
|
||||
### Swapping schedulers
|
||||
|
||||
```python
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
# Swap for faster generation
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipe.scheduler.config
|
||||
)
|
||||
|
||||
# Now generate with fewer steps
|
||||
image = pipe(prompt, num_inference_steps=20).images[0]
|
||||
```
|
||||
|
||||
## Generation parameters
|
||||
|
||||
### Key parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `prompt` | Required | Text description of desired image |
|
||||
| `negative_prompt` | None | What to avoid in the image |
|
||||
| `num_inference_steps` | 50 | Denoising steps (more = better quality) |
|
||||
| `guidance_scale` | 7.5 | Prompt adherence (7-12 typical) |
|
||||
| `height`, `width` | 512/1024 | Output dimensions (multiples of 8) |
|
||||
| `generator` | None | Torch generator for reproducibility |
|
||||
| `num_images_per_prompt` | 1 | Batch size |
|
||||
|
||||
### Reproducible generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(42)
|
||||
|
||||
image = pipe(
|
||||
prompt="A cat wearing a top hat",
|
||||
generator=generator,
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Negative prompts
|
||||
|
||||
```python
|
||||
image = pipe(
|
||||
prompt="Professional photo of a dog in a garden",
|
||||
negative_prompt="blurry, low quality, distorted, ugly, bad anatomy",
|
||||
guidance_scale=7.5
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Image-to-image
|
||||
|
||||
Transform existing images with text guidance:
|
||||
|
||||
```python
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
from PIL import Image
|
||||
|
||||
pipe = AutoPipelineForImage2Image.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
init_image = Image.open("input.jpg").resize((512, 512))
|
||||
|
||||
image = pipe(
|
||||
prompt="A watercolor painting of the scene",
|
||||
image=init_image,
|
||||
strength=0.75, # How much to transform (0-1)
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Inpainting
|
||||
|
||||
Fill masked regions:
|
||||
|
||||
```python
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from PIL import Image
|
||||
|
||||
pipe = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = Image.open("photo.jpg")
|
||||
mask = Image.open("mask.png") # White = inpaint region
|
||||
|
||||
result = pipe(
|
||||
prompt="A red car parked on the street",
|
||||
image=image,
|
||||
mask_image=mask,
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## ControlNet
|
||||
|
||||
Add spatial conditioning for precise control:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
import torch
|
||||
|
||||
# Load ControlNet for edge conditioning
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Use Canny edge image as control
|
||||
control_image = get_canny_image(input_image)
|
||||
|
||||
image = pipe(
|
||||
prompt="A beautiful house in the style of Van Gogh",
|
||||
image=control_image,
|
||||
num_inference_steps=30
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Available ControlNets
|
||||
|
||||
| ControlNet | Input Type | Use Case |
|
||||
|------------|------------|----------|
|
||||
| `canny` | Edge maps | Preserve structure |
|
||||
| `openpose` | Pose skeletons | Human poses |
|
||||
| `depth` | Depth maps | 3D-aware generation |
|
||||
| `normal` | Normal maps | Surface details |
|
||||
| `mlsd` | Line segments | Architectural lines |
|
||||
| `scribble` | Rough sketches | Sketch-to-image |
|
||||
|
||||
## LoRA adapters
|
||||
|
||||
Load fine-tuned style adapters:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Load LoRA weights
|
||||
pipe.load_lora_weights("path/to/lora", weight_name="style.safetensors")
|
||||
|
||||
# Generate with LoRA style
|
||||
image = pipe("A portrait in the trained style").images[0]
|
||||
|
||||
# Adjust LoRA strength
|
||||
pipe.fuse_lora(lora_scale=0.8)
|
||||
|
||||
# Unload LoRA
|
||||
pipe.unload_lora_weights()
|
||||
```
|
||||
|
||||
### Multiple LoRAs
|
||||
|
||||
```python
|
||||
# Load multiple LoRAs
|
||||
pipe.load_lora_weights("lora1", adapter_name="style")
|
||||
pipe.load_lora_weights("lora2", adapter_name="character")
|
||||
|
||||
# Set weights for each
|
||||
pipe.set_adapters(["style", "character"], adapter_weights=[0.7, 0.5])
|
||||
|
||||
image = pipe("A portrait").images[0]
|
||||
```
|
||||
|
||||
## Memory optimization
|
||||
|
||||
### Enable CPU offloading
|
||||
|
||||
```python
|
||||
# Model CPU offload - moves models to CPU when not in use
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Sequential CPU offload - more aggressive, slower
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
```
|
||||
|
||||
### Attention slicing
|
||||
|
||||
```python
|
||||
# Reduce memory by computing attention in chunks
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
# Or specific chunk size
|
||||
pipe.enable_attention_slicing("max")
|
||||
```
|
||||
|
||||
### xFormers memory-efficient attention
|
||||
|
||||
```python
|
||||
# Requires xformers package
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
### VAE slicing for large images
|
||||
|
||||
```python
|
||||
# Decode latents in tiles for large images
|
||||
pipe.enable_vae_slicing()
|
||||
pipe.enable_vae_tiling()
|
||||
```
|
||||
|
||||
## Model variants
|
||||
|
||||
### Loading different precisions
|
||||
|
||||
```python
|
||||
# FP16 (recommended for GPU)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
)
|
||||
|
||||
# BF16 (better precision, requires Ampere+ GPU)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
### Loading specific components
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel, AutoencoderKL
|
||||
|
||||
# Load custom VAE
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
|
||||
|
||||
# Use with pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
vae=vae,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
## Batch generation
|
||||
|
||||
Generate multiple images efficiently:
|
||||
|
||||
```python
|
||||
# Multiple prompts
|
||||
prompts = [
|
||||
"A cat playing piano",
|
||||
"A dog reading a book",
|
||||
"A bird painting a picture"
|
||||
]
|
||||
|
||||
images = pipe(prompts, num_inference_steps=30).images
|
||||
|
||||
# Multiple images per prompt
|
||||
images = pipe(
|
||||
"A beautiful sunset",
|
||||
num_images_per_prompt=4,
|
||||
num_inference_steps=30
|
||||
).images
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: High-quality generation
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
|
||||
import torch
|
||||
|
||||
# 1. Load SDXL with optimizations
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# 2. Generate with quality settings
|
||||
image = pipe(
|
||||
prompt="A majestic lion in the savanna, golden hour lighting, 8k, detailed fur",
|
||||
negative_prompt="blurry, low quality, cartoon, anime, sketch",
|
||||
num_inference_steps=30,
|
||||
guidance_scale=7.5,
|
||||
height=1024,
|
||||
width=1024
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Workflow 2: Fast prototyping
|
||||
|
||||
```python
|
||||
from diffusers import AutoPipelineForText2Image, LCMScheduler
|
||||
import torch
|
||||
|
||||
# Use LCM for 4-8 step generation
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Load LCM LoRA for fast generation
|
||||
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.fuse_lora()
|
||||
|
||||
# Generate in ~1 second
|
||||
image = pipe(
|
||||
"A beautiful landscape",
|
||||
num_inference_steps=4,
|
||||
guidance_scale=1.0
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Common issues
|
||||
|
||||
**CUDA out of memory:**
|
||||
```python
|
||||
# Enable memory optimizations
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_attention_slicing()
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
# Or use lower precision
|
||||
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
**Black/noise images:**
|
||||
```python
|
||||
# Check VAE configuration
|
||||
# Use safety checker bypass if needed
|
||||
pipe.safety_checker = None
|
||||
|
||||
# Ensure proper dtype consistency
|
||||
pipe = pipe.to(dtype=torch.float16)
|
||||
```
|
||||
|
||||
**Slow generation:**
|
||||
```python
|
||||
# Use faster scheduler
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
# Reduce steps
|
||||
image = pipe(prompt, num_inference_steps=20).images[0]
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Custom pipelines, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://huggingface.co/docs/diffusers
|
||||
- **Repository**: https://github.com/huggingface/diffusers
|
||||
- **Model Hub**: https://huggingface.co/models?library=diffusers
|
||||
- **Discord**: https://discord.gg/diffusers
|
||||
716
skills/mlops/stable-diffusion/references/advanced-usage.md
Normal file
716
skills/mlops/stable-diffusion/references/advanced-usage.md
Normal file
|
|
@ -0,0 +1,716 @@
|
|||
# Stable Diffusion Advanced Usage Guide
|
||||
|
||||
## Custom Pipelines
|
||||
|
||||
### Building from components
|
||||
|
||||
```python
|
||||
from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionPipeline
|
||||
)
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import torch
|
||||
|
||||
# Load components individually
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="unet"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="vae"
|
||||
)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="text_encoder"
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="tokenizer"
|
||||
)
|
||||
scheduler = DDPMScheduler.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
subfolder="scheduler"
|
||||
)
|
||||
|
||||
# Assemble pipeline
|
||||
pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False
|
||||
)
|
||||
```
|
||||
|
||||
### Custom denoising loop
|
||||
|
||||
```python
|
||||
from diffusers import DDIMScheduler, AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import torch
|
||||
|
||||
def custom_generate(
|
||||
prompt: str,
|
||||
num_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
height: int = 512,
|
||||
width: int = 512
|
||||
):
|
||||
# Load components
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
unet = UNet2DConditionModel.from_pretrained("sd-model", subfolder="unet")
|
||||
vae = AutoencoderKL.from_pretrained("sd-model", subfolder="vae")
|
||||
scheduler = DDIMScheduler.from_pretrained("sd-model", subfolder="scheduler")
|
||||
|
||||
device = "cuda"
|
||||
text_encoder.to(device)
|
||||
unet.to(device)
|
||||
vae.to(device)
|
||||
|
||||
# Encode prompt
|
||||
text_input = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
|
||||
|
||||
# Unconditional embeddings for classifier-free guidance
|
||||
uncond_input = tokenizer(
|
||||
"",
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
|
||||
|
||||
# Concatenate for batch processing
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# Initialize latents
|
||||
latents = torch.randn(
|
||||
(1, 4, height // 8, width // 8),
|
||||
device=device
|
||||
)
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
|
||||
# Denoising loop
|
||||
scheduler.set_timesteps(num_steps)
|
||||
for t in scheduler.timesteps:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# Predict noise
|
||||
with torch.no_grad():
|
||||
noise_pred = unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=text_embeddings
|
||||
).sample
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_cond - noise_pred_uncond
|
||||
)
|
||||
|
||||
# Update latents
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
|
||||
# Decode latents
|
||||
latents = latents / vae.config.scaling_factor
|
||||
with torch.no_grad():
|
||||
image = vae.decode(latents).sample
|
||||
|
||||
# Convert to PIL
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
image = (image * 255).round().astype("uint8")[0]
|
||||
|
||||
return Image.fromarray(image)
|
||||
```
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
Use image prompts alongside text:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Load IP-Adapter
|
||||
pipe.load_ip_adapter(
|
||||
"h94/IP-Adapter",
|
||||
subfolder="models",
|
||||
weight_name="ip-adapter_sd15.bin"
|
||||
)
|
||||
|
||||
# Set IP-Adapter scale
|
||||
pipe.set_ip_adapter_scale(0.6)
|
||||
|
||||
# Load reference image
|
||||
ip_image = load_image("reference_style.jpg")
|
||||
|
||||
# Generate with image + text prompt
|
||||
image = pipe(
|
||||
prompt="A portrait in a garden",
|
||||
ip_adapter_image=ip_image,
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Multiple IP-Adapter images
|
||||
|
||||
```python
|
||||
# Use multiple reference images
|
||||
pipe.set_ip_adapter_scale([0.5, 0.7])
|
||||
|
||||
images = [
|
||||
load_image("style_reference.jpg"),
|
||||
load_image("composition_reference.jpg")
|
||||
]
|
||||
|
||||
result = pipe(
|
||||
prompt="A landscape painting",
|
||||
ip_adapter_image=images,
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## SDXL Refiner
|
||||
|
||||
Two-stage generation for higher quality:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
base = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
).to("cuda")
|
||||
|
||||
# Load refiner
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16"
|
||||
).to("cuda")
|
||||
|
||||
# Generate with base (partial denoising)
|
||||
image = base(
|
||||
prompt="A majestic eagle soaring over mountains",
|
||||
num_inference_steps=40,
|
||||
denoising_end=0.8,
|
||||
output_type="latent"
|
||||
).images
|
||||
|
||||
# Refine with refiner
|
||||
refined = refiner(
|
||||
prompt="A majestic eagle soaring over mountains",
|
||||
image=image,
|
||||
num_inference_steps=40,
|
||||
denoising_start=0.8
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## T2I-Adapter
|
||||
|
||||
Lightweight conditioning without full ControlNet:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
|
||||
import torch
|
||||
|
||||
# Load adapter
|
||||
adapter = T2IAdapter.from_pretrained(
|
||||
"TencentARC/t2i-adapter-canny-sdxl-1.0",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
adapter=adapter,
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Get canny edges
|
||||
canny_image = get_canny_image(input_image)
|
||||
|
||||
image = pipe(
|
||||
prompt="A colorful anime character",
|
||||
image=canny_image,
|
||||
num_inference_steps=30,
|
||||
adapter_conditioning_scale=0.8
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Fine-tuning with DreamBooth
|
||||
|
||||
Train on custom subjects:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline, DDPMScheduler
|
||||
from diffusers.optimization import get_scheduler
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
def __init__(self, instance_images_path, instance_prompt, tokenizer, size=512):
|
||||
self.instance_images_path = instance_images_path
|
||||
self.instance_prompt = instance_prompt
|
||||
self.tokenizer = tokenizer
|
||||
self.size = size
|
||||
|
||||
self.instance_images = [
|
||||
os.path.join(instance_images_path, f)
|
||||
for f in os.listdir(instance_images_path)
|
||||
if f.endswith(('.png', '.jpg', '.jpeg'))
|
||||
]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.instance_images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image = Image.open(self.instance_images[idx]).convert("RGB")
|
||||
image = image.resize((self.size, self.size))
|
||||
image = torch.tensor(np.array(image)).permute(2, 0, 1) / 127.5 - 1.0
|
||||
|
||||
tokens = self.tokenizer(
|
||||
self.instance_prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
return {"image": image, "input_ids": tokens.input_ids.squeeze()}
|
||||
|
||||
def train_dreambooth(
|
||||
pretrained_model: str,
|
||||
instance_data_dir: str,
|
||||
instance_prompt: str,
|
||||
output_dir: str,
|
||||
learning_rate: float = 5e-6,
|
||||
max_train_steps: int = 800,
|
||||
train_batch_size: int = 1
|
||||
):
|
||||
# Load pipeline
|
||||
pipe = StableDiffusionPipeline.from_pretrained(pretrained_model)
|
||||
|
||||
unet = pipe.unet
|
||||
vae = pipe.vae
|
||||
text_encoder = pipe.text_encoder
|
||||
tokenizer = pipe.tokenizer
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")
|
||||
|
||||
# Freeze VAE and text encoder
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
# Create dataset
|
||||
dataset = DreamBoothDataset(
|
||||
instance_data_dir, instance_prompt, tokenizer
|
||||
)
|
||||
dataloader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
|
||||
lr_scheduler = get_scheduler(
|
||||
"constant",
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=max_train_steps
|
||||
)
|
||||
|
||||
# Training loop
|
||||
unet.train()
|
||||
device = "cuda"
|
||||
unet.to(device)
|
||||
vae.to(device)
|
||||
text_encoder.to(device)
|
||||
|
||||
global_step = 0
|
||||
for epoch in range(max_train_steps // len(dataloader) + 1):
|
||||
for batch in dataloader:
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
# Encode images to latents
|
||||
latents = vae.encode(batch["image"].to(device)).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise
|
||||
noise = torch.randn_like(latents)
|
||||
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (latents.shape[0],))
|
||||
timesteps = timesteps.to(device)
|
||||
|
||||
# Add noise
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get text embeddings
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
|
||||
|
||||
# Predict noise
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Compute loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, noise)
|
||||
|
||||
# Backprop
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
global_step += 1
|
||||
|
||||
if global_step % 100 == 0:
|
||||
print(f"Step {global_step}, Loss: {loss.item():.4f}")
|
||||
|
||||
# Save model
|
||||
pipe.unet = unet
|
||||
pipe.save_pretrained(output_dir)
|
||||
```
|
||||
|
||||
## LoRA Training
|
||||
|
||||
Efficient fine-tuning with Low-Rank Adaptation:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
|
||||
def train_lora(
|
||||
base_model: str,
|
||||
train_dataset,
|
||||
output_dir: str,
|
||||
lora_rank: int = 4,
|
||||
learning_rate: float = 1e-4,
|
||||
max_train_steps: int = 1000
|
||||
):
|
||||
pipe = StableDiffusionPipeline.from_pretrained(base_model)
|
||||
unet = pipe.unet
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_rank,
|
||||
target_modules=["to_q", "to_v", "to_k", "to_out.0"],
|
||||
lora_dropout=0.1
|
||||
)
|
||||
|
||||
# Apply LoRA to UNet
|
||||
unet = get_peft_model(unet, lora_config)
|
||||
unet.print_trainable_parameters() # Shows ~0.1% trainable
|
||||
|
||||
# Train (similar to DreamBooth but only LoRA params)
|
||||
optimizer = torch.optim.AdamW(
|
||||
unet.parameters(),
|
||||
lr=learning_rate
|
||||
)
|
||||
|
||||
# ... training loop ...
|
||||
|
||||
# Save LoRA weights only
|
||||
unet.save_pretrained(output_dir)
|
||||
```
|
||||
|
||||
## Textual Inversion
|
||||
|
||||
Learn new concepts through embeddings:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
|
||||
# Load with textual inversion
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
# Load learned embedding
|
||||
pipe.load_textual_inversion(
|
||||
"sd-concepts-library/cat-toy",
|
||||
token="<cat-toy>"
|
||||
)
|
||||
|
||||
# Use in prompts
|
||||
image = pipe("A photo of <cat-toy> on a beach").images[0]
|
||||
```
|
||||
|
||||
## Quantization
|
||||
|
||||
Reduce memory with quantization:
|
||||
|
||||
```python
|
||||
from diffusers import BitsAndBytesConfig, StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
# 8-bit quantization
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
### NF4 quantization (4-bit)
|
||||
|
||||
```python
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
```
|
||||
|
||||
## Production Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
class GenerationRequest(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: str = ""
|
||||
num_inference_steps: int = 30
|
||||
guidance_scale: float = 7.5
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
seed: int = None
|
||||
|
||||
class GenerationResponse(BaseModel):
|
||||
image_base64: str
|
||||
seed: int
|
||||
|
||||
@app.post("/generate", response_model=GenerationResponse)
|
||||
async def generate(request: GenerationRequest):
|
||||
try:
|
||||
generator = None
|
||||
seed = request.seed or torch.randint(0, 2**32, (1,)).item()
|
||||
generator = torch.Generator("cuda").manual_seed(seed)
|
||||
|
||||
image = pipe(
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
num_inference_steps=request.num_inference_steps,
|
||||
guidance_scale=request.guidance_scale,
|
||||
width=request.width,
|
||||
height=request.height,
|
||||
generator=generator
|
||||
).images[0]
|
||||
|
||||
# Convert to base64
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
image_base64 = base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
return GenerationResponse(image_base64=image_base64, seed=seed)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
```
|
||||
|
||||
### Docker deployment
|
||||
|
||||
```dockerfile
|
||||
FROM nvidia/cuda:12.1-runtime-ubuntu22.04
|
||||
|
||||
RUN apt-get update && apt-get install -y python3 python3-pip
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip3 install -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
# Pre-download model
|
||||
RUN python3 -c "from diffusers import DiffusionPipeline; DiffusionPipeline.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5')"
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
```
|
||||
|
||||
### Kubernetes deployment
|
||||
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: stable-diffusion
|
||||
spec:
|
||||
replicas: 2
|
||||
selector:
|
||||
matchLabels:
|
||||
app: stable-diffusion
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: stable-diffusion
|
||||
spec:
|
||||
containers:
|
||||
- name: sd
|
||||
image: your-registry/stable-diffusion:latest
|
||||
ports:
|
||||
- containerPort: 8000
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 1
|
||||
memory: "16Gi"
|
||||
requests:
|
||||
nvidia.com/gpu: 1
|
||||
memory: "8Gi"
|
||||
env:
|
||||
- name: TRANSFORMERS_CACHE
|
||||
value: "/cache/huggingface"
|
||||
volumeMounts:
|
||||
- name: model-cache
|
||||
mountPath: /cache
|
||||
volumes:
|
||||
- name: model-cache
|
||||
persistentVolumeClaim:
|
||||
claimName: model-cache-pvc
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: stable-diffusion
|
||||
spec:
|
||||
selector:
|
||||
app: stable-diffusion
|
||||
ports:
|
||||
- port: 80
|
||||
targetPort: 8000
|
||||
type: LoadBalancer
|
||||
```
|
||||
|
||||
## Callback System
|
||||
|
||||
Monitor and modify generation:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.callbacks import PipelineCallback
|
||||
import torch
|
||||
|
||||
class ProgressCallback(PipelineCallback):
|
||||
def __init__(self):
|
||||
self.progress = []
|
||||
|
||||
def callback_fn(self, pipe, step_index, timestep, callback_kwargs):
|
||||
self.progress.append({
|
||||
"step": step_index,
|
||||
"timestep": timestep.item()
|
||||
})
|
||||
|
||||
# Optionally modify latents
|
||||
latents = callback_kwargs["latents"]
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
# Use callback
|
||||
callback = ProgressCallback()
|
||||
|
||||
image = pipe(
|
||||
prompt="A sunset",
|
||||
callback_on_step_end=callback.callback_fn,
|
||||
callback_on_step_end_tensor_inputs=["latents"]
|
||||
).images[0]
|
||||
|
||||
print(f"Generation completed in {len(callback.progress)} steps")
|
||||
```
|
||||
|
||||
### Early stopping
|
||||
|
||||
```python
|
||||
def early_stop_callback(pipe, step_index, timestep, callback_kwargs):
|
||||
# Stop after 20 steps
|
||||
if step_index >= 20:
|
||||
pipe._interrupt = True
|
||||
return callback_kwargs
|
||||
|
||||
image = pipe(
|
||||
prompt="A landscape",
|
||||
num_inference_steps=50,
|
||||
callback_on_step_end=early_stop_callback
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Multi-GPU Inference
|
||||
|
||||
### Device map auto
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
device_map="auto", # Automatically distribute across GPUs
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
### Manual distribution
|
||||
|
||||
```python
|
||||
from accelerate import infer_auto_device_map, dispatch_model
|
||||
|
||||
# Create device map
|
||||
device_map = infer_auto_device_map(
|
||||
pipe.unet,
|
||||
max_memory={0: "10GiB", 1: "10GiB"}
|
||||
)
|
||||
|
||||
# Dispatch model
|
||||
pipe.unet = dispatch_model(pipe.unet, device_map=device_map)
|
||||
```
|
||||
555
skills/mlops/stable-diffusion/references/troubleshooting.md
Normal file
555
skills/mlops/stable-diffusion/references/troubleshooting.md
Normal file
|
|
@ -0,0 +1,555 @@
|
|||
# Stable Diffusion Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Package conflicts
|
||||
|
||||
**Error**: `ImportError: cannot import name 'cached_download' from 'huggingface_hub'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Update huggingface_hub
|
||||
pip install --upgrade huggingface_hub
|
||||
|
||||
# Reinstall diffusers
|
||||
pip install --upgrade diffusers
|
||||
```
|
||||
|
||||
### xFormers installation fails
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available for execution`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
|
||||
# Install matching xformers
|
||||
pip install xformers --index-url https://download.pytorch.org/whl/cu121 # For CUDA 12.1
|
||||
|
||||
# Or build from source
|
||||
pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
|
||||
```
|
||||
|
||||
### Torch/CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check versions
|
||||
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
|
||||
|
||||
# Reinstall PyTorch with correct CUDA
|
||||
pip uninstall torch torchvision
|
||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
|
||||
```python
|
||||
# Solution 1: Enable CPU offloading
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Solution 2: Sequential CPU offload (more aggressive)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
# Solution 3: Attention slicing
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
# Solution 4: VAE slicing for large images
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
# Solution 5: Use lower precision
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
torch_dtype=torch.float16 # or torch.bfloat16
|
||||
)
|
||||
|
||||
# Solution 6: Reduce batch size
|
||||
image = pipe(prompt, num_images_per_prompt=1).images[0]
|
||||
|
||||
# Solution 7: Generate smaller images
|
||||
image = pipe(prompt, height=512, width=512).images[0]
|
||||
|
||||
# Solution 8: Clear cache between generations
|
||||
import gc
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
```
|
||||
|
||||
### Memory grows over time
|
||||
|
||||
**Problem**: Memory usage increases with each generation
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(pipe, prompt, **kwargs):
|
||||
try:
|
||||
image = pipe(prompt, **kwargs).images[0]
|
||||
return image
|
||||
finally:
|
||||
# Clear cache after generation
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
```
|
||||
|
||||
### Large model loading fails
|
||||
|
||||
**Error**: `RuntimeError: Unable to load model weights`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Use low CPU memory mode
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"large-model-id",
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Black images
|
||||
|
||||
**Problem**: Output images are completely black
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Disable safety checker
|
||||
pipe.safety_checker = None
|
||||
|
||||
# Solution 2: Check VAE scaling
|
||||
# The issue might be with VAE encoding/decoding
|
||||
latents = latents / pipe.vae.config.scaling_factor # Before decode
|
||||
|
||||
# Solution 3: Ensure proper dtype
|
||||
pipe = pipe.to(dtype=torch.float16)
|
||||
pipe.vae = pipe.vae.to(dtype=torch.float32) # VAE often needs fp32
|
||||
|
||||
# Solution 4: Check guidance scale
|
||||
# Too high can cause issues
|
||||
image = pipe(prompt, guidance_scale=7.5).images[0] # Not 20+
|
||||
```
|
||||
|
||||
### Noise/static images
|
||||
|
||||
**Problem**: Output looks like random noise
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Increase inference steps
|
||||
image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
|
||||
# Solution 2: Check scheduler configuration
|
||||
pipe.scheduler = pipe.scheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
# Solution 3: Verify model was loaded correctly
|
||||
print(pipe.unet) # Should show model architecture
|
||||
```
|
||||
|
||||
### Blurry images
|
||||
|
||||
**Problem**: Output images are low quality or blurry
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Use more steps
|
||||
image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
|
||||
# Solution 2: Use better VAE
|
||||
from diffusers import AutoencoderKL
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
|
||||
pipe.vae = vae
|
||||
|
||||
# Solution 3: Use SDXL or refiner
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0"
|
||||
)
|
||||
|
||||
# Solution 4: Upscale with img2img
|
||||
upscale_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(...)
|
||||
upscaled = upscale_pipe(
|
||||
prompt=prompt,
|
||||
image=image.resize((1024, 1024)),
|
||||
strength=0.3
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Prompt not being followed
|
||||
|
||||
**Problem**: Generated image doesn't match the prompt
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Increase guidance scale
|
||||
image = pipe(prompt, guidance_scale=10.0).images[0]
|
||||
|
||||
# Solution 2: Use negative prompts
|
||||
image = pipe(
|
||||
prompt="A red car",
|
||||
negative_prompt="blue, green, yellow, wrong color",
|
||||
guidance_scale=7.5
|
||||
).images[0]
|
||||
|
||||
# Solution 3: Use prompt weighting
|
||||
# Emphasize important words
|
||||
prompt = "A (red:1.5) car on a street"
|
||||
|
||||
# Solution 4: Use longer, more detailed prompts
|
||||
prompt = """
|
||||
A bright red sports car, ferrari style, parked on a city street,
|
||||
photorealistic, high detail, 8k, professional photography
|
||||
"""
|
||||
```
|
||||
|
||||
### Distorted faces/hands
|
||||
|
||||
**Problem**: Faces and hands look deformed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Use negative prompts
|
||||
negative_prompt = """
|
||||
bad hands, bad anatomy, deformed, ugly, blurry,
|
||||
extra fingers, mutated hands, poorly drawn hands,
|
||||
poorly drawn face, mutation, deformed face
|
||||
"""
|
||||
|
||||
# Solution 2: Use face-specific models
|
||||
# ADetailer or similar post-processing
|
||||
|
||||
# Solution 3: Use ControlNet for poses
|
||||
# Load pose estimation and condition generation
|
||||
|
||||
# Solution 4: Inpaint problematic areas
|
||||
mask = create_face_mask(image)
|
||||
fixed = inpaint_pipe(
|
||||
prompt="beautiful detailed face",
|
||||
image=image,
|
||||
mask_image=mask
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Scheduler Issues
|
||||
|
||||
### Scheduler not compatible
|
||||
|
||||
**Error**: `ValueError: Scheduler ... is not compatible with pipeline`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
# Create scheduler from config
|
||||
pipe.scheduler = EulerDiscreteScheduler.from_config(
|
||||
pipe.scheduler.config
|
||||
)
|
||||
|
||||
# Check compatible schedulers
|
||||
print(pipe.scheduler.compatibles)
|
||||
```
|
||||
|
||||
### Wrong number of steps
|
||||
|
||||
**Problem**: Model generates different quality with same steps
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Reset timesteps explicitly
|
||||
pipe.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Check scheduler's step count
|
||||
print(len(pipe.scheduler.timesteps))
|
||||
```
|
||||
|
||||
## LoRA Issues
|
||||
|
||||
### LoRA weights not loading
|
||||
|
||||
**Error**: `RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check weight file format
|
||||
# Should be .safetensors or .bin
|
||||
|
||||
# Load with correct key prefix
|
||||
pipe.load_lora_weights(
|
||||
"path/to/lora",
|
||||
weight_name="lora.safetensors"
|
||||
)
|
||||
|
||||
# Try loading into specific component
|
||||
pipe.unet.load_attn_procs("path/to/lora")
|
||||
```
|
||||
|
||||
### LoRA not affecting output
|
||||
|
||||
**Problem**: Generated images look the same with/without LoRA
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Fuse LoRA weights
|
||||
pipe.fuse_lora(lora_scale=1.0)
|
||||
|
||||
# Or set scale explicitly
|
||||
pipe.set_adapters(["lora_name"], adapter_weights=[1.0])
|
||||
|
||||
# Verify LoRA is loaded
|
||||
print(list(pipe.unet.attn_processors.keys()))
|
||||
```
|
||||
|
||||
### Multiple LoRAs conflict
|
||||
|
||||
**Problem**: Multiple LoRAs produce artifacts
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Load with different adapter names
|
||||
pipe.load_lora_weights("lora1", adapter_name="style")
|
||||
pipe.load_lora_weights("lora2", adapter_name="subject")
|
||||
|
||||
# Balance weights
|
||||
pipe.set_adapters(
|
||||
["style", "subject"],
|
||||
adapter_weights=[0.5, 0.5] # Lower weights
|
||||
)
|
||||
|
||||
# Or use LoRA merge before loading
|
||||
# Merge LoRAs offline with appropriate ratios
|
||||
```
|
||||
|
||||
## ControlNet Issues
|
||||
|
||||
### ControlNet not conditioning
|
||||
|
||||
**Problem**: ControlNet has no effect on output
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Check control image format
|
||||
# Should be RGB, matching generation size
|
||||
control_image = control_image.resize((512, 512))
|
||||
|
||||
# Increase conditioning scale
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
image=control_image,
|
||||
controlnet_conditioning_scale=1.0, # Try 0.5-1.5
|
||||
num_inference_steps=30
|
||||
).images[0]
|
||||
|
||||
# Verify ControlNet is loaded
|
||||
print(pipe.controlnet)
|
||||
```
|
||||
|
||||
### Control image preprocessing
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
from controlnet_aux import CannyDetector
|
||||
|
||||
# Proper preprocessing
|
||||
canny = CannyDetector()
|
||||
control_image = canny(input_image)
|
||||
|
||||
# Ensure correct format
|
||||
control_image = control_image.convert("RGB")
|
||||
control_image = control_image.resize((512, 512))
|
||||
```
|
||||
|
||||
## Hub/Download Issues
|
||||
|
||||
### Model download fails
|
||||
|
||||
**Error**: `requests.exceptions.ConnectionError`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Set longer timeout
|
||||
export HF_HUB_DOWNLOAD_TIMEOUT=600
|
||||
|
||||
# Use mirror if available
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# Or download manually
|
||||
huggingface-cli download stable-diffusion-v1-5/stable-diffusion-v1-5
|
||||
```
|
||||
|
||||
### Cache issues
|
||||
|
||||
**Error**: `OSError: Can't load model from cache`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Clear cache
|
||||
rm -rf ~/.cache/huggingface/hub
|
||||
|
||||
# Or set different cache location
|
||||
export HF_HOME=/path/to/cache
|
||||
|
||||
# Force re-download
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
force_download=True
|
||||
)
|
||||
```
|
||||
|
||||
### Access denied for gated models
|
||||
|
||||
**Error**: `401 Client Error: Unauthorized`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Login to Hugging Face
|
||||
huggingface-cli login
|
||||
|
||||
# Or use token
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"model-id",
|
||||
token="hf_xxxxx"
|
||||
)
|
||||
|
||||
# Accept model license on Hub website first
|
||||
```
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Solution 1: Use faster scheduler
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipe.scheduler.config
|
||||
)
|
||||
|
||||
# Solution 2: Reduce steps
|
||||
image = pipe(prompt, num_inference_steps=20).images[0]
|
||||
|
||||
# Solution 3: Use LCM
|
||||
from diffusers import LCMScheduler
|
||||
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
image = pipe(prompt, num_inference_steps=4, guidance_scale=1.0).images[0]
|
||||
|
||||
# Solution 4: Enable xFormers
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# Solution 5: Compile model
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
### First generation is slow
|
||||
|
||||
**Problem**: First image takes much longer
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Warm up the model
|
||||
_ = pipe("warmup", num_inference_steps=1)
|
||||
|
||||
# Then run actual generation
|
||||
image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
|
||||
# Compile for faster subsequent runs
|
||||
pipe.unet = torch.compile(pipe.unet)
|
||||
```
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
### Enable debug logging
|
||||
|
||||
```python
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# Or for specific modules
|
||||
logging.getLogger("diffusers").setLevel(logging.DEBUG)
|
||||
logging.getLogger("transformers").setLevel(logging.DEBUG)
|
||||
```
|
||||
|
||||
### Check model components
|
||||
|
||||
```python
|
||||
# Print pipeline components
|
||||
print(pipe.components)
|
||||
|
||||
# Check model config
|
||||
print(pipe.unet.config)
|
||||
print(pipe.vae.config)
|
||||
print(pipe.scheduler.config)
|
||||
|
||||
# Verify device placement
|
||||
print(pipe.device)
|
||||
for name, module in pipe.components.items():
|
||||
if hasattr(module, 'device'):
|
||||
print(f"{name}: {module.device}")
|
||||
```
|
||||
|
||||
### Validate inputs
|
||||
|
||||
```python
|
||||
# Check image dimensions
|
||||
print(f"Height: {height}, Width: {width}")
|
||||
assert height % 8 == 0, "Height must be divisible by 8"
|
||||
assert width % 8 == 0, "Width must be divisible by 8"
|
||||
|
||||
# Check prompt tokenization
|
||||
tokens = pipe.tokenizer(prompt, return_tensors="pt")
|
||||
print(f"Token count: {tokens.input_ids.shape[1]}") # Max 77 for SD
|
||||
```
|
||||
|
||||
### Save intermediate results
|
||||
|
||||
```python
|
||||
def save_latents_callback(pipe, step_index, timestep, callback_kwargs):
|
||||
latents = callback_kwargs["latents"]
|
||||
|
||||
# Decode and save intermediate
|
||||
with torch.no_grad():
|
||||
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
|
||||
Image.fromarray((image * 255).astype("uint8")).save(f"step_{step_index}.png")
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
callback_on_step_end=save_latents_callback,
|
||||
callback_on_step_end_tensor_inputs=["latents"]
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **Documentation**: https://huggingface.co/docs/diffusers
|
||||
2. **GitHub Issues**: https://github.com/huggingface/diffusers/issues
|
||||
3. **Discord**: https://discord.gg/diffusers
|
||||
4. **Forum**: https://discuss.huggingface.co
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Diffusers version: `pip show diffusers`
|
||||
- PyTorch version: `python -c "import torch; print(torch.__version__)"`
|
||||
- CUDA version: `nvcc --version`
|
||||
- GPU model: `nvidia-smi`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Model name/ID used
|
||||
190
skills/mlops/tensorrt-llm/SKILL.md
Normal file
190
skills/mlops/tensorrt-llm/SKILL.md
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
---
|
||||
name: tensorrt-llm
|
||||
description: Optimizes LLM inference with NVIDIA TensorRT for maximum throughput and lowest latency. Use for production deployment on NVIDIA GPUs (A100/H100), when you need 10-100x faster inference than PyTorch, or for serving models with quantization (FP8/INT4), in-flight batching, and multi-GPU scaling.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [tensorrt-llm, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Inference Serving, TensorRT-LLM, NVIDIA, Inference Optimization, High Throughput, Low Latency, Production, FP8, INT4, In-Flight Batching, Multi-GPU]
|
||||
|
||||
---
|
||||
|
||||
# TensorRT-LLM
|
||||
|
||||
NVIDIA's open-source library for optimizing LLM inference with state-of-the-art performance on NVIDIA GPUs.
|
||||
|
||||
## When to use TensorRT-LLM
|
||||
|
||||
**Use TensorRT-LLM when:**
|
||||
- Deploying on NVIDIA GPUs (A100, H100, GB200)
|
||||
- Need maximum throughput (24,000+ tokens/sec on Llama 3)
|
||||
- Require low latency for real-time applications
|
||||
- Working with quantized models (FP8, INT4, FP4)
|
||||
- Scaling across multiple GPUs or nodes
|
||||
|
||||
**Use vLLM instead when:**
|
||||
- Need simpler setup and Python-first API
|
||||
- Want PagedAttention without TensorRT compilation
|
||||
- Working with AMD GPUs or non-NVIDIA hardware
|
||||
|
||||
**Use llama.cpp instead when:**
|
||||
- Deploying on CPU or Apple Silicon
|
||||
- Need edge deployment without NVIDIA GPUs
|
||||
- Want simpler GGUF quantization format
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Docker (recommended)
|
||||
docker pull nvidia/tensorrt_llm:latest
|
||||
|
||||
# pip install
|
||||
pip install tensorrt_llm==1.2.0rc3
|
||||
|
||||
# Requires CUDA 13.0.0, TensorRT 10.13.2, Python 3.10-3.12
|
||||
```
|
||||
|
||||
### Basic inference
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
|
||||
# Initialize model
|
||||
llm = LLM(model="meta-llama/Meta-Llama-3-8B")
|
||||
|
||||
# Configure sampling
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9
|
||||
)
|
||||
|
||||
# Generate
|
||||
prompts = ["Explain quantum computing"]
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
for output in outputs:
|
||||
print(output.text)
|
||||
```
|
||||
|
||||
### Serving with trtllm-serve
|
||||
|
||||
```bash
|
||||
# Start server (automatic model download and compilation)
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--tp_size 4 \ # Tensor parallelism (4 GPUs)
|
||||
--max_batch_size 256 \
|
||||
--max_num_tokens 4096
|
||||
|
||||
# Client request
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
## Key features
|
||||
|
||||
### Performance optimizations
|
||||
- **In-flight batching**: Dynamic batching during generation
|
||||
- **Paged KV cache**: Efficient memory management
|
||||
- **Flash Attention**: Optimized attention kernels
|
||||
- **Quantization**: FP8, INT4, FP4 for 2-4× faster inference
|
||||
- **CUDA graphs**: Reduced kernel launch overhead
|
||||
|
||||
### Parallelism
|
||||
- **Tensor parallelism (TP)**: Split model across GPUs
|
||||
- **Pipeline parallelism (PP)**: Layer-wise distribution
|
||||
- **Expert parallelism**: For Mixture-of-Experts models
|
||||
- **Multi-node**: Scale beyond single machine
|
||||
|
||||
### Advanced features
|
||||
- **Speculative decoding**: Faster generation with draft models
|
||||
- **LoRA serving**: Efficient multi-adapter deployment
|
||||
- **Disaggregated serving**: Separate prefill and generation
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Quantized model (FP8)
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
# Load FP8 quantized model (2× faster, 50% memory)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
dtype="fp8",
|
||||
max_num_tokens=8192
|
||||
)
|
||||
|
||||
# Inference same as before
|
||||
outputs = llm.generate(["Summarize this article..."])
|
||||
```
|
||||
|
||||
### Multi-GPU deployment
|
||||
|
||||
```python
|
||||
# Tensor parallelism across 8 GPUs
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
tensor_parallel_size=8,
|
||||
dtype="fp8"
|
||||
)
|
||||
```
|
||||
|
||||
### Batch inference
|
||||
|
||||
```python
|
||||
# Process 100 prompts efficiently
|
||||
prompts = [f"Question {i}: ..." for i in range(100)]
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params=SamplingParams(max_tokens=200)
|
||||
)
|
||||
|
||||
# Automatic in-flight batching for maximum throughput
|
||||
```
|
||||
|
||||
## Performance benchmarks
|
||||
|
||||
**Meta Llama 3-8B** (H100 GPU):
|
||||
- Throughput: 24,000 tokens/sec
|
||||
- Latency: ~10ms per token
|
||||
- vs PyTorch: **100× faster**
|
||||
|
||||
**Llama 3-70B** (8× A100 80GB):
|
||||
- FP8 quantization: 2× faster than FP16
|
||||
- Memory: 50% reduction with FP8
|
||||
|
||||
## Supported models
|
||||
|
||||
- **LLaMA family**: Llama 2, Llama 3, CodeLlama
|
||||
- **GPT family**: GPT-2, GPT-J, GPT-NeoX
|
||||
- **Qwen**: Qwen, Qwen2, QwQ
|
||||
- **DeepSeek**: DeepSeek-V2, DeepSeek-V3
|
||||
- **Mixtral**: Mixtral-8x7B, Mixtral-8x22B
|
||||
- **Vision**: LLaVA, Phi-3-vision
|
||||
- **100+ models** on HuggingFace
|
||||
|
||||
## References
|
||||
|
||||
- **[Optimization Guide](references/optimization.md)** - Quantization, batching, KV cache tuning
|
||||
- **[Multi-GPU Setup](references/multi-gpu.md)** - Tensor/pipeline parallelism, multi-node
|
||||
- **[Serving Guide](references/serving.md)** - Production deployment, monitoring, autoscaling
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://nvidia.github.io/TensorRT-LLM/
|
||||
- **GitHub**: https://github.com/NVIDIA/TensorRT-LLM
|
||||
- **Models**: https://huggingface.co/models?library=tensorrt_llm
|
||||
|
||||
|
||||
298
skills/mlops/tensorrt-llm/references/multi-gpu.md
Normal file
298
skills/mlops/tensorrt-llm/references/multi-gpu.md
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
# Multi-GPU Deployment Guide
|
||||
|
||||
Comprehensive guide to scaling TensorRT-LLM across multiple GPUs and nodes.
|
||||
|
||||
## Parallelism Strategies
|
||||
|
||||
### Tensor Parallelism (TP)
|
||||
|
||||
**What it does**: Splits model layers across GPUs horizontally.
|
||||
|
||||
**Use case**:
|
||||
- Model fits in total GPU memory but not single GPU
|
||||
- Need low latency (single forward pass)
|
||||
- GPUs on same node (NVLink required for best performance)
|
||||
|
||||
**Example** (Llama 3-70B on 4× A100):
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
tensor_parallel_size=4, # Split across 4 GPUs
|
||||
dtype="fp16"
|
||||
)
|
||||
|
||||
# Model automatically sharded across GPUs
|
||||
# Single forward pass, low latency
|
||||
```
|
||||
|
||||
**Performance**:
|
||||
- Latency: ~Same as single GPU
|
||||
- Throughput: 4× higher (4 GPUs)
|
||||
- Communication: High (activations synced every layer)
|
||||
|
||||
### Pipeline Parallelism (PP)
|
||||
|
||||
**What it does**: Splits model layers across GPUs vertically (layer-wise).
|
||||
|
||||
**Use case**:
|
||||
- Very large models (175B+)
|
||||
- Can tolerate higher latency
|
||||
- GPUs across multiple nodes
|
||||
|
||||
**Example** (Llama 3-405B on 8× H100):
|
||||
```python
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
tensor_parallel_size=4, # TP=4 within nodes
|
||||
pipeline_parallel_size=2, # PP=2 across nodes
|
||||
dtype="fp8"
|
||||
)
|
||||
|
||||
# Total: 8 GPUs (4×2)
|
||||
# Layers 0-40: Node 1 (4 GPUs with TP)
|
||||
# Layers 41-80: Node 2 (4 GPUs with TP)
|
||||
```
|
||||
|
||||
**Performance**:
|
||||
- Latency: Higher (sequential through pipeline)
|
||||
- Throughput: High with micro-batching
|
||||
- Communication: Lower than TP
|
||||
|
||||
### Expert Parallelism (EP)
|
||||
|
||||
**What it does**: Distributes MoE experts across GPUs.
|
||||
|
||||
**Use case**: Mixture-of-Experts models (Mixtral, DeepSeek-V2)
|
||||
|
||||
**Example** (Mixtral-8x22B on 8× A100):
|
||||
```python
|
||||
llm = LLM(
|
||||
model="mistralai/Mixtral-8x22B",
|
||||
tensor_parallel_size=4,
|
||||
expert_parallel_size=2, # Distribute 8 experts across 2 groups
|
||||
dtype="fp8"
|
||||
)
|
||||
```
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
### Small model (7-13B) - Single GPU
|
||||
|
||||
```python
|
||||
# Llama 3-8B on 1× A100 80GB
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-8B",
|
||||
dtype="fp16" # or fp8 for H100
|
||||
)
|
||||
```
|
||||
|
||||
**Resources**:
|
||||
- GPU: 1× A100 80GB
|
||||
- Memory: ~16GB model + 30GB KV cache
|
||||
- Throughput: 3,000-5,000 tokens/sec
|
||||
|
||||
### Medium model (70B) - Multi-GPU same node
|
||||
|
||||
```python
|
||||
# Llama 3-70B on 4× A100 80GB (NVLink)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
tensor_parallel_size=4,
|
||||
dtype="fp8" # 70GB → 35GB per GPU
|
||||
)
|
||||
```
|
||||
|
||||
**Resources**:
|
||||
- GPU: 4× A100 80GB with NVLink
|
||||
- Memory: ~35GB per GPU (FP8)
|
||||
- Throughput: 10,000-15,000 tokens/sec
|
||||
- Latency: 15-20ms per token
|
||||
|
||||
### Large model (405B) - Multi-node
|
||||
|
||||
```python
|
||||
# Llama 3-405B on 2 nodes × 8 H100 = 16 GPUs
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
tensor_parallel_size=8, # TP within each node
|
||||
pipeline_parallel_size=2, # PP across 2 nodes
|
||||
dtype="fp8"
|
||||
)
|
||||
```
|
||||
|
||||
**Resources**:
|
||||
- GPU: 2 nodes × 8 H100 80GB
|
||||
- Memory: ~25GB per GPU (FP8)
|
||||
- Throughput: 20,000-30,000 tokens/sec
|
||||
- Network: InfiniBand recommended
|
||||
|
||||
## Server Deployment
|
||||
|
||||
### Single-node multi-GPU
|
||||
|
||||
```bash
|
||||
# Llama 3-70B on 4 GPUs (automatic TP)
|
||||
trtllm-serve meta-llama/Meta-Llama-3-70B \
|
||||
--tp_size 4 \
|
||||
--max_batch_size 256 \
|
||||
--dtype fp8
|
||||
|
||||
# Listens on http://localhost:8000
|
||||
```
|
||||
|
||||
### Multi-node with Ray
|
||||
|
||||
```bash
|
||||
# Node 1 (head node)
|
||||
ray start --head --port=6379
|
||||
|
||||
# Node 2 (worker)
|
||||
ray start --address='node1:6379'
|
||||
|
||||
# Deploy across cluster
|
||||
trtllm-serve meta-llama/Meta-Llama-3-405B \
|
||||
--tp_size 8 \
|
||||
--pp_size 2 \
|
||||
--num_workers 2 \ # 2 nodes
|
||||
--dtype fp8
|
||||
```
|
||||
|
||||
### Kubernetes deployment
|
||||
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: tensorrt-llm-llama3-70b
|
||||
spec:
|
||||
replicas: 1
|
||||
template:
|
||||
spec:
|
||||
containers:
|
||||
- name: trtllm
|
||||
image: nvidia/tensorrt_llm:latest
|
||||
command:
|
||||
- trtllm-serve
|
||||
- meta-llama/Meta-Llama-3-70B
|
||||
- --tp_size=4
|
||||
- --max_batch_size=256
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 4 # Request 4 GPUs
|
||||
```
|
||||
|
||||
## Parallelism Decision Tree
|
||||
|
||||
```
|
||||
Model size < 20GB?
|
||||
├─ YES: Single GPU (no parallelism)
|
||||
└─ NO: Model size < 80GB?
|
||||
├─ YES: TP=2 or TP=4 (same node)
|
||||
└─ NO: Model size < 320GB?
|
||||
├─ YES: TP=4 or TP=8 (same node, NVLink required)
|
||||
└─ NO: TP=8 + PP=2 (multi-node)
|
||||
```
|
||||
|
||||
## Communication Optimization
|
||||
|
||||
### NVLink vs PCIe
|
||||
|
||||
**NVLink** (DGX A100, HGX H100):
|
||||
- Bandwidth: 600 GB/s (A100), 900 GB/s (H100)
|
||||
- Ideal for TP (high communication)
|
||||
- **Recommended for all multi-GPU setups**
|
||||
|
||||
**PCIe**:
|
||||
- Bandwidth: 64 GB/s (PCIe 4.0 x16)
|
||||
- 10× slower than NVLink
|
||||
- Avoid TP, use PP instead
|
||||
|
||||
### InfiniBand for multi-node
|
||||
|
||||
**HDR InfiniBand** (200 Gb/s):
|
||||
- Required for multi-node TP or PP
|
||||
- Latency: <1μs
|
||||
- **Essential for 405B+ models**
|
||||
|
||||
## Monitoring Multi-GPU
|
||||
|
||||
```python
|
||||
# Monitor GPU utilization
|
||||
nvidia-smi dmon -s u
|
||||
|
||||
# Monitor memory
|
||||
nvidia-smi dmon -s m
|
||||
|
||||
# Monitor NVLink utilization
|
||||
nvidia-smi nvlink --status
|
||||
|
||||
# TensorRT-LLM built-in metrics
|
||||
curl http://localhost:8000/metrics
|
||||
```
|
||||
|
||||
**Key metrics**:
|
||||
- GPU utilization: Target 80-95%
|
||||
- Memory usage: Should be balanced across GPUs
|
||||
- NVLink traffic: High for TP, low for PP
|
||||
- Throughput: Tokens/sec across all GPUs
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Imbalanced GPU memory
|
||||
|
||||
**Symptom**: GPU 0 has 90% memory, GPU 3 has 40%
|
||||
|
||||
**Solutions**:
|
||||
- Verify TP/PP configuration
|
||||
- Check model sharding (should be equal)
|
||||
- Restart server to reset state
|
||||
|
||||
### Low NVLink utilization
|
||||
|
||||
**Symptom**: NVLink bandwidth <100 GB/s with TP=4
|
||||
|
||||
**Solutions**:
|
||||
- Verify NVLink topology: `nvidia-smi topo -m`
|
||||
- Check for PCIe fallback
|
||||
- Ensure GPUs are on same NVSwitch
|
||||
|
||||
### OOM with multi-GPU
|
||||
|
||||
**Solutions**:
|
||||
- Increase TP size (more GPUs)
|
||||
- Reduce batch size
|
||||
- Enable FP8 quantization
|
||||
- Use pipeline parallelism
|
||||
|
||||
## Performance Scaling
|
||||
|
||||
### TP Scaling (Llama 3-70B, FP8)
|
||||
|
||||
| GPUs | TP Size | Throughput | Latency | Efficiency |
|
||||
|------|---------|------------|---------|------------|
|
||||
| 1 | 1 | OOM | - | - |
|
||||
| 2 | 2 | 6,000 tok/s | 18ms | 85% |
|
||||
| 4 | 4 | 11,000 tok/s | 16ms | 78% |
|
||||
| 8 | 8 | 18,000 tok/s | 15ms | 64% |
|
||||
|
||||
**Note**: Efficiency drops with more GPUs due to communication overhead.
|
||||
|
||||
### PP Scaling (Llama 3-405B, FP8)
|
||||
|
||||
| Nodes | TP | PP | Total GPUs | Throughput |
|
||||
|-------|----|----|------------|------------|
|
||||
| 1 | 8 | 1 | 8 | OOM |
|
||||
| 2 | 8 | 2 | 16 | 25,000 tok/s |
|
||||
| 4 | 8 | 4 | 32 | 45,000 tok/s |
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Prefer TP over PP** when possible (lower latency)
|
||||
2. **Use NVLink** for all TP deployments
|
||||
3. **Use InfiniBand** for multi-node deployments
|
||||
4. **Start with smallest TP** that fits model in memory
|
||||
5. **Monitor GPU balance** - all GPUs should have similar utilization
|
||||
6. **Test with benchmark** before production
|
||||
7. **Use FP8** on H100 for 2× speedup
|
||||
242
skills/mlops/tensorrt-llm/references/optimization.md
Normal file
242
skills/mlops/tensorrt-llm/references/optimization.md
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
# TensorRT-LLM Optimization Guide
|
||||
|
||||
Comprehensive guide to optimizing LLM inference with TensorRT-LLM.
|
||||
|
||||
## Quantization
|
||||
|
||||
### FP8 Quantization (Recommended for H100)
|
||||
|
||||
**Benefits**:
|
||||
- 2× faster inference
|
||||
- 50% memory reduction
|
||||
- Minimal accuracy loss (<1% perplexity degradation)
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
# Automatic FP8 quantization
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
dtype="fp8",
|
||||
quantization="fp8"
|
||||
)
|
||||
```
|
||||
|
||||
**Performance** (Llama 3-70B on 8× H100):
|
||||
- FP16: 5,000 tokens/sec
|
||||
- FP8: **10,000 tokens/sec** (2× speedup)
|
||||
- Memory: 140GB → 70GB
|
||||
|
||||
### INT4 Quantization (Maximum compression)
|
||||
|
||||
**Benefits**:
|
||||
- 4× memory reduction
|
||||
- 3-4× faster inference
|
||||
- Fits larger models on same hardware
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
# INT4 with AWQ calibration
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
dtype="int4_awq",
|
||||
quantization="awq"
|
||||
)
|
||||
|
||||
# INT4 with GPTQ calibration
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-405B",
|
||||
dtype="int4_gptq",
|
||||
quantization="gptq"
|
||||
)
|
||||
```
|
||||
|
||||
**Trade-offs**:
|
||||
- Accuracy: 1-3% perplexity increase
|
||||
- Speed: 3-4× faster than FP16
|
||||
- Use case: When memory is critical
|
||||
|
||||
## In-Flight Batching
|
||||
|
||||
**What it does**: Dynamically batches requests during generation instead of waiting for all sequences to finish.
|
||||
|
||||
**Configuration**:
|
||||
```python
|
||||
# Server configuration
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--max_batch_size 256 \ # Maximum concurrent sequences
|
||||
--max_num_tokens 4096 \ # Total tokens in batch
|
||||
--enable_chunked_context \ # Split long prompts
|
||||
--scheduler_policy max_utilization
|
||||
```
|
||||
|
||||
**Performance**:
|
||||
- Throughput: **4-8× higher** vs static batching
|
||||
- Latency: Lower P50/P99 for mixed workloads
|
||||
- GPU utilization: 80-95% vs 40-60%
|
||||
|
||||
## Paged KV Cache
|
||||
|
||||
**What it does**: Manages KV cache memory like OS manages virtual memory (paging).
|
||||
|
||||
**Benefits**:
|
||||
- 40-60% higher throughput
|
||||
- No memory fragmentation
|
||||
- Supports longer sequences
|
||||
|
||||
**Configuration**:
|
||||
```python
|
||||
# Automatic paged KV cache (default)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-8B",
|
||||
kv_cache_free_gpu_mem_fraction=0.9, # Use 90% GPU mem for cache
|
||||
enable_prefix_caching=True # Cache common prefixes
|
||||
)
|
||||
```
|
||||
|
||||
## Speculative Decoding
|
||||
|
||||
**What it does**: Uses small draft model to predict multiple tokens, verified by target model in parallel.
|
||||
|
||||
**Speedup**: 2-3× faster for long generations
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
# Target model (Llama 3-70B)
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-70B",
|
||||
speculative_model="meta-llama/Meta-Llama-3-8B", # Draft model
|
||||
num_speculative_tokens=5 # Tokens to predict ahead
|
||||
)
|
||||
|
||||
# Same API, 2-3× faster
|
||||
outputs = llm.generate(prompts)
|
||||
```
|
||||
|
||||
**Best models for drafting**:
|
||||
- Target: Llama 3-70B → Draft: Llama 3-8B
|
||||
- Target: Qwen2-72B → Draft: Qwen2-7B
|
||||
- Same family, 8-10× smaller
|
||||
|
||||
## CUDA Graphs
|
||||
|
||||
**What it does**: Reduces kernel launch overhead by recording GPU operations.
|
||||
|
||||
**Benefits**:
|
||||
- 10-20% lower latency
|
||||
- More stable P99 latency
|
||||
- Better for small batch sizes
|
||||
|
||||
**Configuration** (automatic by default):
|
||||
```python
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-8B",
|
||||
enable_cuda_graph=True, # Default: True
|
||||
cuda_graph_cache_size=2 # Cache 2 graph variants
|
||||
)
|
||||
```
|
||||
|
||||
## Chunked Context
|
||||
|
||||
**What it does**: Splits long prompts into chunks to reduce memory spikes.
|
||||
|
||||
**Use case**: Prompts >8K tokens with limited GPU memory
|
||||
|
||||
**Configuration**:
|
||||
```bash
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--max_num_tokens 4096 \
|
||||
--enable_chunked_context \
|
||||
--max_chunked_prefill_length 2048 # Process 2K tokens at a time
|
||||
```
|
||||
|
||||
## Overlap Scheduling
|
||||
|
||||
**What it does**: Overlaps compute and memory operations.
|
||||
|
||||
**Benefits**:
|
||||
- 15-25% higher throughput
|
||||
- Better GPU utilization
|
||||
- Default in v1.2.0+
|
||||
|
||||
**No configuration needed** - enabled automatically.
|
||||
|
||||
## Quantization Comparison Table
|
||||
|
||||
| Method | Memory | Speed | Accuracy | Use Case |
|
||||
|--------|--------|-------|----------|----------|
|
||||
| FP16 | 1× (baseline) | 1× | Best | High accuracy needed |
|
||||
| FP8 | 0.5× | 2× | -0.5% ppl | **H100 default** |
|
||||
| INT4 AWQ | 0.25× | 3-4× | -1.5% ppl | Memory critical |
|
||||
| INT4 GPTQ | 0.25× | 3-4× | -2% ppl | Maximum speed |
|
||||
|
||||
## Tuning Workflow
|
||||
|
||||
1. **Start with defaults**:
|
||||
```python
|
||||
llm = LLM(model="meta-llama/Meta-Llama-3-70B")
|
||||
```
|
||||
|
||||
2. **Enable FP8** (if H100):
|
||||
```python
|
||||
llm = LLM(model="...", dtype="fp8")
|
||||
```
|
||||
|
||||
3. **Tune batch size**:
|
||||
```python
|
||||
# Increase until OOM, then reduce 20%
|
||||
trtllm-serve ... --max_batch_size 256
|
||||
```
|
||||
|
||||
4. **Enable chunked context** (if long prompts):
|
||||
```bash
|
||||
--enable_chunked_context --max_chunked_prefill_length 2048
|
||||
```
|
||||
|
||||
5. **Try speculative decoding** (if latency critical):
|
||||
```python
|
||||
llm = LLM(model="...", speculative_model="...")
|
||||
```
|
||||
|
||||
## Benchmarking
|
||||
|
||||
```bash
|
||||
# Install benchmark tool
|
||||
pip install tensorrt_llm[benchmark]
|
||||
|
||||
# Run benchmark
|
||||
python benchmarks/python/benchmark.py \
|
||||
--model meta-llama/Meta-Llama-3-8B \
|
||||
--batch_size 64 \
|
||||
--input_len 128 \
|
||||
--output_len 256 \
|
||||
--dtype fp8
|
||||
```
|
||||
|
||||
**Metrics to track**:
|
||||
- Throughput (tokens/sec)
|
||||
- Latency P50/P90/P99 (ms)
|
||||
- GPU memory usage (GB)
|
||||
- GPU utilization (%)
|
||||
|
||||
## Common Issues
|
||||
|
||||
**OOM errors**:
|
||||
- Reduce `max_batch_size`
|
||||
- Reduce `max_num_tokens`
|
||||
- Enable INT4 quantization
|
||||
- Increase `tensor_parallel_size`
|
||||
|
||||
**Low throughput**:
|
||||
- Increase `max_batch_size`
|
||||
- Enable in-flight batching
|
||||
- Verify CUDA graphs enabled
|
||||
- Check GPU utilization
|
||||
|
||||
**High latency**:
|
||||
- Try speculative decoding
|
||||
- Reduce `max_batch_size` (less queueing)
|
||||
- Use FP8 instead of FP16
|
||||
470
skills/mlops/tensorrt-llm/references/serving.md
Normal file
470
skills/mlops/tensorrt-llm/references/serving.md
Normal file
|
|
@ -0,0 +1,470 @@
|
|||
# Production Serving Guide
|
||||
|
||||
Comprehensive guide to deploying TensorRT-LLM in production environments.
|
||||
|
||||
## Server Modes
|
||||
|
||||
### trtllm-serve (Recommended)
|
||||
|
||||
**Features**:
|
||||
- OpenAI-compatible API
|
||||
- Automatic model download and compilation
|
||||
- Built-in load balancing
|
||||
- Prometheus metrics
|
||||
- Health checks
|
||||
|
||||
**Basic usage**:
|
||||
```bash
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--tp_size 1 \
|
||||
--max_batch_size 256 \
|
||||
--port 8000
|
||||
```
|
||||
|
||||
**Advanced configuration**:
|
||||
```bash
|
||||
trtllm-serve meta-llama/Meta-Llama-3-70B \
|
||||
--tp_size 4 \
|
||||
--dtype fp8 \
|
||||
--max_batch_size 256 \
|
||||
--max_num_tokens 4096 \
|
||||
--enable_chunked_context \
|
||||
--scheduler_policy max_utilization \
|
||||
--port 8000 \
|
||||
--api_key $API_KEY # Optional authentication
|
||||
```
|
||||
|
||||
### Python LLM API (For embedding)
|
||||
|
||||
```python
|
||||
from tensorrt_llm import LLM
|
||||
|
||||
class LLMService:
|
||||
def __init__(self):
|
||||
self.llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-8B",
|
||||
dtype="fp8"
|
||||
)
|
||||
|
||||
def generate(self, prompt, max_tokens=100):
|
||||
from tensorrt_llm import SamplingParams
|
||||
|
||||
params = SamplingParams(
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.7
|
||||
)
|
||||
outputs = self.llm.generate([prompt], params)
|
||||
return outputs[0].text
|
||||
|
||||
# Use in FastAPI, Flask, etc
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
service = LLMService()
|
||||
|
||||
@app.post("/generate")
|
||||
def generate(prompt: str):
|
||||
return {"response": service.generate(prompt)}
|
||||
```
|
||||
|
||||
## OpenAI-Compatible API
|
||||
|
||||
### Chat Completions
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Explain quantum computing"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
**Response**:
|
||||
```json
|
||||
{
|
||||
"id": "chat-abc123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Quantum computing is..."
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 25,
|
||||
"completion_tokens": 150,
|
||||
"total_tokens": 175
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Streaming
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"messages": [{"role": "user", "content": "Count to 10"}],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
**Response** (SSE stream):
|
||||
```
|
||||
data: {"choices":[{"delta":{"content":"1"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"content":", 2"}}]}
|
||||
|
||||
data: {"choices":[{"delta":{"content":", 3"}}]}
|
||||
|
||||
data: [DONE]
|
||||
```
|
||||
|
||||
### Completions
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"prompt": "The capital of France is",
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.0
|
||||
}'
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Prometheus Metrics
|
||||
|
||||
**Enable metrics**:
|
||||
```bash
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--enable_metrics \
|
||||
--metrics_port 9090
|
||||
```
|
||||
|
||||
**Key metrics**:
|
||||
```bash
|
||||
# Scrape metrics
|
||||
curl http://localhost:9090/metrics
|
||||
|
||||
# Important metrics:
|
||||
# - trtllm_request_success_total - Total successful requests
|
||||
# - trtllm_request_latency_seconds - Request latency histogram
|
||||
# - trtllm_tokens_generated_total - Total tokens generated
|
||||
# - trtllm_active_requests - Current active requests
|
||||
# - trtllm_queue_size - Requests waiting in queue
|
||||
# - trtllm_gpu_memory_usage_bytes - GPU memory usage
|
||||
# - trtllm_kv_cache_usage_ratio - KV cache utilization
|
||||
```
|
||||
|
||||
### Health Checks
|
||||
|
||||
```bash
|
||||
# Readiness probe
|
||||
curl http://localhost:8000/health/ready
|
||||
|
||||
# Liveness probe
|
||||
curl http://localhost:8000/health/live
|
||||
|
||||
# Model info
|
||||
curl http://localhost:8000/v1/models
|
||||
```
|
||||
|
||||
**Kubernetes probes**:
|
||||
```yaml
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health/live
|
||||
port: 8000
|
||||
initialDelaySeconds: 60
|
||||
periodSeconds: 10
|
||||
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health/ready
|
||||
port: 8000
|
||||
initialDelaySeconds: 30
|
||||
periodSeconds: 5
|
||||
```
|
||||
|
||||
## Production Deployment
|
||||
|
||||
### Docker Deployment
|
||||
|
||||
**Dockerfile**:
|
||||
```dockerfile
|
||||
FROM nvidia/tensorrt_llm:latest
|
||||
|
||||
# Copy any custom configs
|
||||
COPY config.yaml /app/config.yaml
|
||||
|
||||
# Expose ports
|
||||
EXPOSE 8000 9090
|
||||
|
||||
# Start server
|
||||
CMD ["trtllm-serve", "meta-llama/Meta-Llama-3-8B", \
|
||||
"--tp_size", "4", \
|
||||
"--dtype", "fp8", \
|
||||
"--max_batch_size", "256", \
|
||||
"--enable_metrics", \
|
||||
"--metrics_port", "9090"]
|
||||
```
|
||||
|
||||
**Run container**:
|
||||
```bash
|
||||
docker run --gpus all -p 8000:8000 -p 9090:9090 \
|
||||
tensorrt-llm:latest
|
||||
```
|
||||
|
||||
### Kubernetes Deployment
|
||||
|
||||
**Complete deployment**:
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: tensorrt-llm
|
||||
spec:
|
||||
replicas: 2 # Multiple replicas for HA
|
||||
selector:
|
||||
matchLabels:
|
||||
app: tensorrt-llm
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: tensorrt-llm
|
||||
spec:
|
||||
containers:
|
||||
- name: trtllm
|
||||
image: nvidia/tensorrt_llm:latest
|
||||
command:
|
||||
- trtllm-serve
|
||||
- meta-llama/Meta-Llama-3-70B
|
||||
- --tp_size=4
|
||||
- --dtype=fp8
|
||||
- --max_batch_size=256
|
||||
- --enable_metrics
|
||||
ports:
|
||||
- containerPort: 8000
|
||||
name: http
|
||||
- containerPort: 9090
|
||||
name: metrics
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 4
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health/live
|
||||
port: 8000
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health/ready
|
||||
port: 8000
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: tensorrt-llm
|
||||
spec:
|
||||
selector:
|
||||
app: tensorrt-llm
|
||||
ports:
|
||||
- name: http
|
||||
port: 80
|
||||
targetPort: 8000
|
||||
- name: metrics
|
||||
port: 9090
|
||||
targetPort: 9090
|
||||
type: LoadBalancer
|
||||
```
|
||||
|
||||
### Load Balancing
|
||||
|
||||
**NGINX configuration**:
|
||||
```nginx
|
||||
upstream tensorrt_llm {
|
||||
least_conn; # Route to least busy server
|
||||
server trtllm-1:8000 max_fails=3 fail_timeout=30s;
|
||||
server trtllm-2:8000 max_fails=3 fail_timeout=30s;
|
||||
server trtllm-3:8000 max_fails=3 fail_timeout=30s;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80;
|
||||
location / {
|
||||
proxy_pass http://tensorrt_llm;
|
||||
proxy_read_timeout 300s; # Long timeout for slow generations
|
||||
proxy_connect_timeout 10s;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Autoscaling
|
||||
|
||||
### Horizontal Pod Autoscaler (HPA)
|
||||
|
||||
```yaml
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: tensorrt-llm-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: tensorrt-llm
|
||||
minReplicas: 2
|
||||
maxReplicas: 10
|
||||
metrics:
|
||||
- type: Pods
|
||||
pods:
|
||||
metric:
|
||||
name: trtllm_active_requests
|
||||
target:
|
||||
type: AverageValue
|
||||
averageValue: "50" # Scale when avg >50 active requests
|
||||
```
|
||||
|
||||
### Custom Metrics
|
||||
|
||||
```yaml
|
||||
# Scale based on queue size
|
||||
- type: Pods
|
||||
pods:
|
||||
metric:
|
||||
name: trtllm_queue_size
|
||||
target:
|
||||
type: AverageValue
|
||||
averageValue: "10"
|
||||
```
|
||||
|
||||
## Cost Optimization
|
||||
|
||||
### GPU Selection
|
||||
|
||||
**A100 80GB** ($3-4/hour):
|
||||
- Use for: 70B models with FP8
|
||||
- Throughput: 10,000-15,000 tok/s (TP=4)
|
||||
- Cost per 1M tokens: $0.20-0.30
|
||||
|
||||
**H100 80GB** ($6-8/hour):
|
||||
- Use for: 70B models with FP8, 405B models
|
||||
- Throughput: 20,000-30,000 tok/s (TP=4)
|
||||
- Cost per 1M tokens: $0.15-0.25 (2× faster = lower cost)
|
||||
|
||||
**L4** ($0.50-1/hour):
|
||||
- Use for: 7-8B models
|
||||
- Throughput: 1,000-2,000 tok/s
|
||||
- Cost per 1M tokens: $0.25-0.50
|
||||
|
||||
### Batch Size Tuning
|
||||
|
||||
**Impact on cost**:
|
||||
- Batch size 1: 1,000 tok/s → $3/hour per 1M = $3/M tokens
|
||||
- Batch size 64: 5,000 tok/s → $3/hour per 5M = $0.60/M tokens
|
||||
- **5× cost reduction** with batching
|
||||
|
||||
**Recommendation**: Target batch size 32-128 for cost efficiency.
|
||||
|
||||
## Security
|
||||
|
||||
### API Authentication
|
||||
|
||||
```bash
|
||||
# Generate API key
|
||||
export API_KEY=$(openssl rand -hex 32)
|
||||
|
||||
# Start server with authentication
|
||||
trtllm-serve meta-llama/Meta-Llama-3-8B \
|
||||
--api_key $API_KEY
|
||||
|
||||
# Client request
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Authorization: Bearer $API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "...", "messages": [...]}'
|
||||
```
|
||||
|
||||
### Network Policies
|
||||
|
||||
```yaml
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: NetworkPolicy
|
||||
metadata:
|
||||
name: tensorrt-llm-policy
|
||||
spec:
|
||||
podSelector:
|
||||
matchLabels:
|
||||
app: tensorrt-llm
|
||||
policyTypes:
|
||||
- Ingress
|
||||
ingress:
|
||||
- from:
|
||||
- podSelector:
|
||||
matchLabels:
|
||||
app: api-gateway # Only allow from gateway
|
||||
ports:
|
||||
- protocol: TCP
|
||||
port: 8000
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### High latency
|
||||
|
||||
**Diagnosis**:
|
||||
```bash
|
||||
# Check queue size
|
||||
curl http://localhost:9090/metrics | grep queue_size
|
||||
|
||||
# Check active requests
|
||||
curl http://localhost:9090/metrics | grep active_requests
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
- Scale horizontally (more replicas)
|
||||
- Increase batch size (if GPU underutilized)
|
||||
- Enable chunked context (if long prompts)
|
||||
- Use FP8 quantization
|
||||
|
||||
### OOM crashes
|
||||
|
||||
**Solutions**:
|
||||
- Reduce `max_batch_size`
|
||||
- Reduce `max_num_tokens`
|
||||
- Enable FP8 or INT4 quantization
|
||||
- Increase `tensor_parallel_size`
|
||||
|
||||
### Timeout errors
|
||||
|
||||
**NGINX config**:
|
||||
```nginx
|
||||
proxy_read_timeout 600s; # 10 minutes for very long generations
|
||||
proxy_send_timeout 600s;
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use FP8 on H100** for 2× speedup and 50% cost reduction
|
||||
2. **Monitor metrics** - Set up Prometheus + Grafana
|
||||
3. **Set readiness probes** - Prevent routing to unhealthy pods
|
||||
4. **Use load balancing** - Distribute load across replicas
|
||||
5. **Tune batch size** - Balance latency and throughput
|
||||
6. **Enable streaming** - Better UX for chat applications
|
||||
7. **Set up autoscaling** - Handle traffic spikes
|
||||
8. **Use persistent volumes** - Cache compiled models
|
||||
9. **Implement retries** - Handle transient failures
|
||||
10. **Monitor costs** - Track cost per token
|
||||
361
skills/mlops/torchtitan/SKILL.md
Normal file
361
skills/mlops/torchtitan/SKILL.md
Normal file
|
|
@ -0,0 +1,361 @@
|
|||
---
|
||||
name: distributed-llm-pretraining-torchtitan
|
||||
description: Provides PyTorch-native distributed LLM pretraining using torchtitan with 4D parallelism (FSDP2, TP, PP, CP). Use when pretraining Llama 3.1, DeepSeek V3, or custom models at scale from 8 to 512+ GPUs with Float8, torch.compile, and distributed checkpointing.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [torch>=2.6.0, torchtitan>=0.2.0, torchao>=0.5.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Model Architecture, Distributed Training, TorchTitan, FSDP2, Tensor Parallel, Pipeline Parallel, Context Parallel, Float8, Llama, Pretraining]
|
||||
|
||||
---
|
||||
|
||||
# TorchTitan - PyTorch Native Distributed LLM Pretraining
|
||||
|
||||
## Quick start
|
||||
|
||||
TorchTitan is PyTorch's official platform for large-scale LLM pretraining with composable 4D parallelism (FSDP2, TP, PP, CP), achieving 65%+ speedups over baselines on H100 GPUs.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
# From PyPI (stable)
|
||||
pip install torchtitan
|
||||
|
||||
# From source (latest features, requires PyTorch nightly)
|
||||
git clone https://github.com/pytorch/torchtitan
|
||||
cd torchtitan
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
**Download tokenizer**:
|
||||
```bash
|
||||
# Get HF token from https://huggingface.co/settings/tokens
|
||||
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer --hf_token=...
|
||||
```
|
||||
|
||||
**Start training on 8 GPUs**:
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Pretrain Llama 3.1 8B on single node
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Single Node Pretraining:
|
||||
- [ ] Step 1: Download tokenizer
|
||||
- [ ] Step 2: Configure training
|
||||
- [ ] Step 3: Launch training
|
||||
- [ ] Step 4: Monitor and checkpoint
|
||||
```
|
||||
|
||||
**Step 1: Download tokenizer**
|
||||
|
||||
```bash
|
||||
python scripts/download_hf_assets.py \
|
||||
--repo_id meta-llama/Llama-3.1-8B \
|
||||
--assets tokenizer \
|
||||
--hf_token=YOUR_HF_TOKEN
|
||||
```
|
||||
|
||||
**Step 2: Configure training**
|
||||
|
||||
Edit or create a TOML config file:
|
||||
|
||||
```toml
|
||||
# llama3_8b_custom.toml
|
||||
[job]
|
||||
dump_folder = "./outputs"
|
||||
description = "Llama 3.1 8B training"
|
||||
|
||||
[model]
|
||||
name = "llama3"
|
||||
flavor = "8B"
|
||||
hf_assets_path = "./assets/hf/Llama-3.1-8B"
|
||||
|
||||
[optimizer]
|
||||
name = "AdamW"
|
||||
lr = 3e-4
|
||||
|
||||
[lr_scheduler]
|
||||
warmup_steps = 200
|
||||
|
||||
[training]
|
||||
local_batch_size = 2
|
||||
seq_len = 8192
|
||||
max_norm = 1.0
|
||||
steps = 1000
|
||||
dataset = "c4"
|
||||
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = -1 # Use all GPUs for FSDP
|
||||
|
||||
[activation_checkpoint]
|
||||
mode = "selective"
|
||||
selective_ac_option = "op"
|
||||
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
```
|
||||
|
||||
**Step 3: Launch training**
|
||||
|
||||
```bash
|
||||
# 8 GPUs on single node
|
||||
CONFIG_FILE="./llama3_8b_custom.toml" ./run_train.sh
|
||||
|
||||
# Or explicitly with torchrun
|
||||
torchrun --nproc_per_node=8 \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_8b_custom.toml
|
||||
```
|
||||
|
||||
**Step 4: Monitor and checkpoint**
|
||||
|
||||
TensorBoard logs are saved to `./outputs/tb/`:
|
||||
```bash
|
||||
tensorboard --logdir ./outputs/tb
|
||||
```
|
||||
|
||||
### Workflow 2: Multi-node training with SLURM
|
||||
|
||||
```
|
||||
Multi-Node Training:
|
||||
- [ ] Step 1: Configure parallelism for scale
|
||||
- [ ] Step 2: Set up SLURM script
|
||||
- [ ] Step 3: Submit job
|
||||
- [ ] Step 4: Resume from checkpoint
|
||||
```
|
||||
|
||||
**Step 1: Configure parallelism for scale**
|
||||
|
||||
For 70B model on 256 GPUs (32 nodes):
|
||||
```toml
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = 32 # FSDP across 32 ranks
|
||||
tensor_parallel_degree = 8 # TP within node
|
||||
pipeline_parallel_degree = 1 # No PP for 70B
|
||||
context_parallel_degree = 1 # Increase for long sequences
|
||||
```
|
||||
|
||||
**Step 2: Set up SLURM script**
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=llama70b
|
||||
#SBATCH --nodes=32
|
||||
#SBATCH --ntasks-per-node=8
|
||||
#SBATCH --gpus-per-node=8
|
||||
|
||||
srun torchrun \
|
||||
--nnodes=32 \
|
||||
--nproc_per_node=8 \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_70b.toml
|
||||
```
|
||||
|
||||
**Step 3: Submit job**
|
||||
|
||||
```bash
|
||||
sbatch multinode_trainer.slurm
|
||||
```
|
||||
|
||||
**Step 4: Resume from checkpoint**
|
||||
|
||||
Training auto-resumes if checkpoint exists in configured folder.
|
||||
|
||||
### Workflow 3: Enable Float8 training for H100s
|
||||
|
||||
Float8 provides 30-50% speedup on H100 GPUs.
|
||||
|
||||
```
|
||||
Float8 Training:
|
||||
- [ ] Step 1: Install torchao
|
||||
- [ ] Step 2: Configure Float8
|
||||
- [ ] Step 3: Launch with compile
|
||||
```
|
||||
|
||||
**Step 1: Install torchao**
|
||||
|
||||
```bash
|
||||
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
|
||||
```
|
||||
|
||||
**Step 2: Configure Float8**
|
||||
|
||||
Add to your TOML config:
|
||||
```toml
|
||||
[model]
|
||||
converters = ["quantize.linear.float8"]
|
||||
|
||||
[quantize.linear.float8]
|
||||
enable_fsdp_float8_all_gather = true
|
||||
precompute_float8_dynamic_scale_for_fsdp = true
|
||||
filter_fqns = ["output"] # Exclude output layer
|
||||
|
||||
[compile]
|
||||
enable = true
|
||||
components = ["model", "loss"]
|
||||
```
|
||||
|
||||
**Step 3: Launch with compile**
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.enable_fsdp_float8_all_gather \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
### Workflow 4: 4D parallelism for 405B models
|
||||
|
||||
```
|
||||
4D Parallelism (FSDP + TP + PP + CP):
|
||||
- [ ] Step 1: Create seed checkpoint
|
||||
- [ ] Step 2: Configure 4D parallelism
|
||||
- [ ] Step 3: Launch on 512 GPUs
|
||||
```
|
||||
|
||||
**Step 1: Create seed checkpoint**
|
||||
|
||||
Required for consistent initialization across PP stages:
|
||||
```bash
|
||||
NGPU=1 CONFIG_FILE=./llama3_405b.toml ./run_train.sh \
|
||||
--checkpoint.enable \
|
||||
--checkpoint.create_seed_checkpoint \
|
||||
--parallelism.data_parallel_shard_degree 1 \
|
||||
--parallelism.tensor_parallel_degree 1 \
|
||||
--parallelism.pipeline_parallel_degree 1
|
||||
```
|
||||
|
||||
**Step 2: Configure 4D parallelism**
|
||||
|
||||
```toml
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = 8 # FSDP
|
||||
tensor_parallel_degree = 8 # TP within node
|
||||
pipeline_parallel_degree = 8 # PP across nodes
|
||||
context_parallel_degree = 1 # CP for long sequences
|
||||
|
||||
[training]
|
||||
local_batch_size = 32
|
||||
seq_len = 8192
|
||||
```
|
||||
|
||||
**Step 3: Launch on 512 GPUs**
|
||||
|
||||
```bash
|
||||
# 64 nodes x 8 GPUs = 512 GPUs
|
||||
srun torchrun --nnodes=64 --nproc_per_node=8 \
|
||||
-m torchtitan.train \
|
||||
--job.config_file ./llama3_405b.toml
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use TorchTitan when:**
|
||||
- Pretraining LLMs from scratch (8B to 405B+)
|
||||
- Need PyTorch-native solution without third-party dependencies
|
||||
- Require composable 4D parallelism (FSDP2, TP, PP, CP)
|
||||
- Training on H100s with Float8 support
|
||||
- Want interoperable checkpoints with torchtune/HuggingFace
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Megatron-LM**: Maximum performance for NVIDIA-only deployments
|
||||
- **DeepSpeed**: Broader ZeRO optimization ecosystem, inference support
|
||||
- **Axolotl/TRL**: Fine-tuning rather than pretraining
|
||||
- **LitGPT**: Educational, smaller-scale training
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Out of memory on large models**
|
||||
|
||||
Enable activation checkpointing and reduce batch size:
|
||||
```toml
|
||||
[activation_checkpoint]
|
||||
mode = "full" # Instead of "selective"
|
||||
|
||||
[training]
|
||||
local_batch_size = 1
|
||||
```
|
||||
|
||||
Or use gradient accumulation:
|
||||
```toml
|
||||
[training]
|
||||
local_batch_size = 1
|
||||
global_batch_size = 32 # Accumulates gradients
|
||||
```
|
||||
|
||||
**Issue: TP causes high memory with async collectives**
|
||||
|
||||
Set environment variable:
|
||||
```bash
|
||||
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
|
||||
```
|
||||
|
||||
**Issue: Float8 training not faster**
|
||||
|
||||
Float8 only benefits large GEMMs. Filter small layers:
|
||||
```toml
|
||||
[quantize.linear.float8]
|
||||
filter_fqns = ["attention.wk", "attention.wv", "output", "auto_filter_small_kn"]
|
||||
```
|
||||
|
||||
**Issue: Checkpoint loading fails after parallelism change**
|
||||
|
||||
Use DCP's resharding capability:
|
||||
```bash
|
||||
# Convert sharded checkpoint to single file
|
||||
python -m torch.distributed.checkpoint.format_utils \
|
||||
dcp_to_torch checkpoint/step-1000 checkpoint.pt
|
||||
```
|
||||
|
||||
**Issue: Pipeline parallelism initialization**
|
||||
|
||||
Create seed checkpoint first (see Workflow 4, Step 1).
|
||||
|
||||
## Supported models
|
||||
|
||||
| Model | Sizes | Status |
|
||||
|-------|-------|--------|
|
||||
| Llama 3.1 | 8B, 70B, 405B | Production |
|
||||
| Llama 4 | Various | Experimental |
|
||||
| DeepSeek V3 | 16B, 236B, 671B (MoE) | Experimental |
|
||||
| GPT-OSS | 20B, 120B (MoE) | Experimental |
|
||||
| Qwen 3 | Various | Experimental |
|
||||
| Flux | Diffusion | Experimental |
|
||||
|
||||
## Performance benchmarks (H100)
|
||||
|
||||
| Model | GPUs | Parallelism | TPS/GPU | Techniques |
|
||||
|-------|------|-------------|---------|------------|
|
||||
| Llama 8B | 8 | FSDP | 5,762 | Baseline |
|
||||
| Llama 8B | 8 | FSDP+compile+FP8 | 8,532 | +48% |
|
||||
| Llama 70B | 256 | FSDP+TP+AsyncTP | 876 | 2D parallel |
|
||||
| Llama 405B | 512 | FSDP+TP+PP | 128 | 3D parallel |
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**FSDP2 configuration**: See [references/fsdp.md](references/fsdp.md) for detailed FSDP2 vs FSDP1 comparison and ZeRO equivalents.
|
||||
|
||||
**Float8 training**: See [references/float8.md](references/float8.md) for tensorwise vs rowwise scaling recipes.
|
||||
|
||||
**Checkpointing**: See [references/checkpoint.md](references/checkpoint.md) for HuggingFace conversion and async checkpointing.
|
||||
|
||||
**Adding custom models**: See [references/custom-models.md](references/custom-models.md) for TrainSpec protocol.
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub: https://github.com/pytorch/torchtitan
|
||||
- Paper: https://arxiv.org/abs/2410.06511
|
||||
- ICLR 2025: https://iclr.cc/virtual/2025/poster/29620
|
||||
- PyTorch Forum: https://discuss.pytorch.org/c/distributed/torchtitan/44
|
||||
|
||||
181
skills/mlops/torchtitan/references/checkpoint.md
Normal file
181
skills/mlops/torchtitan/references/checkpoint.md
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
# Checkpointing in TorchTitan
|
||||
|
||||
TorchTitan uses PyTorch Distributed Checkpoint (DCP) for fault-tolerant, interoperable checkpointing.
|
||||
|
||||
## Basic Configuration
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
```
|
||||
|
||||
## Save Model Only (Smaller Checkpoints)
|
||||
|
||||
Exclude optimizer state and training metadata:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
last_save_model_only = true
|
||||
export_dtype = "bfloat16" # Optional: export in lower precision
|
||||
```
|
||||
|
||||
## Excluding Keys from Loading
|
||||
|
||||
Partial checkpoint loading for modified settings:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
exclude_from_loading = ["data_loader", "lr_scheduler"]
|
||||
```
|
||||
|
||||
CLI equivalent:
|
||||
```bash
|
||||
--checkpoint.exclude_from_loading data_loader,lr_scheduler
|
||||
```
|
||||
|
||||
## Creating Seed Checkpoints
|
||||
|
||||
Required for Pipeline Parallelism to ensure consistent initialization:
|
||||
|
||||
```bash
|
||||
NGPU=1 CONFIG_FILE=<path_to_config> ./run_train.sh \
|
||||
--checkpoint.enable \
|
||||
--checkpoint.create_seed_checkpoint \
|
||||
--parallelism.data_parallel_replicate_degree 1 \
|
||||
--parallelism.data_parallel_shard_degree 1 \
|
||||
--parallelism.tensor_parallel_degree 1 \
|
||||
--parallelism.pipeline_parallel_degree 1 \
|
||||
--parallelism.context_parallel_degree 1 \
|
||||
--parallelism.expert_parallel_degree 1
|
||||
```
|
||||
|
||||
This initializes on single CPU for reproducible initialization across any GPU count.
|
||||
|
||||
## Async Checkpointing
|
||||
|
||||
Reduce checkpoint overhead with async writes:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
async_mode = "async" # Options: "disabled", "async", "async_with_pinned_mem"
|
||||
```
|
||||
|
||||
## HuggingFace Conversion
|
||||
|
||||
### During Training
|
||||
|
||||
Save directly in HuggingFace format:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
last_save_in_hf = true
|
||||
last_save_model_only = true
|
||||
```
|
||||
|
||||
Load from HuggingFace:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
initial_load_in_hf = true
|
||||
|
||||
[model]
|
||||
hf_assets_path = "./path/to/hf/checkpoint"
|
||||
```
|
||||
|
||||
### Offline Conversion
|
||||
|
||||
Convert without running training:
|
||||
|
||||
```bash
|
||||
# HuggingFace -> TorchTitan
|
||||
python ./scripts/checkpoint_conversion/convert_from_hf.py \
|
||||
<input_dir> <output_dir> \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
|
||||
# TorchTitan -> HuggingFace
|
||||
python ./scripts/checkpoint_conversion/convert_to_hf.py \
|
||||
<input_dir> <output_dir> \
|
||||
--hf_assets_path ./assets/hf/Llama3.1-8B \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
```
|
||||
|
||||
### Example
|
||||
|
||||
```bash
|
||||
python ./scripts/convert_from_hf.py \
|
||||
~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ \
|
||||
./initial_load_path/ \
|
||||
--model_name llama3 \
|
||||
--model_flavor 8B
|
||||
```
|
||||
|
||||
## Converting to Single .pt File
|
||||
|
||||
Convert DCP sharded checkpoint to single PyTorch file:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.checkpoint.format_utils \
|
||||
dcp_to_torch \
|
||||
torchtitan/outputs/checkpoint/step-1000 \
|
||||
checkpoint.pt
|
||||
```
|
||||
|
||||
## Checkpoint Structure
|
||||
|
||||
DCP saves sharded checkpoints that can be resharded for different parallelism configurations:
|
||||
|
||||
```
|
||||
checkpoint/
|
||||
├── step-500/
|
||||
│ ├── .metadata
|
||||
│ ├── __0_0.distcp
|
||||
│ ├── __0_1.distcp
|
||||
│ └── ...
|
||||
└── step-1000/
|
||||
└── ...
|
||||
```
|
||||
|
||||
## Resume Training
|
||||
|
||||
Training auto-resumes from the latest checkpoint in the configured folder. To resume from a specific step:
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
load_step = 500 # Resume from step 500
|
||||
```
|
||||
|
||||
## Interoperability with TorchTune
|
||||
|
||||
Checkpoints saved with `last_save_model_only = true` can be loaded directly into [torchtune](https://github.com/pytorch/torchtune) for fine-tuning.
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
```toml
|
||||
[checkpoint]
|
||||
enable = true
|
||||
folder = "checkpoint"
|
||||
interval = 500
|
||||
load_step = -1 # -1 = latest, or specify step number
|
||||
last_save_model_only = true
|
||||
export_dtype = "bfloat16"
|
||||
async_mode = "async"
|
||||
exclude_from_loading = []
|
||||
last_save_in_hf = false
|
||||
initial_load_in_hf = false
|
||||
create_seed_checkpoint = false
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Large models**: Use `async_mode = "async"` to overlap checkpoint saves with training
|
||||
2. **Fine-tuning export**: Enable `last_save_model_only` and `export_dtype = "bfloat16"` for smaller files
|
||||
3. **Pipeline parallelism**: Always create seed checkpoint first
|
||||
4. **Debugging**: Save frequent checkpoints during development, reduce for production
|
||||
5. **HF interop**: Use conversion scripts for offline conversion, direct save/load for training workflows
|
||||
258
skills/mlops/torchtitan/references/custom-models.md
Normal file
258
skills/mlops/torchtitan/references/custom-models.md
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
# Adding Custom Models to TorchTitan
|
||||
|
||||
This guide explains how to add a new model to TorchTitan following the established patterns.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
torchtitan/models/your_model/
|
||||
├── model/
|
||||
│ ├── __init__.py
|
||||
│ ├── args.py # Model arguments
|
||||
│ ├── model.py # Model definition
|
||||
│ └── state_dict_adapter.py # HF conversion (optional)
|
||||
├── infra/
|
||||
│ ├── __init__.py
|
||||
│ ├── parallelize.py # TP, FSDP, compile application
|
||||
│ └── pipeline.py # PP application (optional)
|
||||
├── train_configs/
|
||||
│ ├── debug_model.toml
|
||||
│ └── your_model_XB.toml
|
||||
├── __init__.py # TrainSpec registration
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Step 1: Define Model Arguments
|
||||
|
||||
Inherit from `BaseModelArgs`:
|
||||
|
||||
```python
|
||||
# model/args.py
|
||||
from torchtitan.protocols.model import BaseModelArgs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class YourModelArgs(BaseModelArgs):
|
||||
dim: int = 4096
|
||||
n_layers: int = 32
|
||||
n_heads: int = 32
|
||||
vocab_size: int = 128256
|
||||
|
||||
def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
|
||||
"""Return (num_params, flops_per_token) for throughput calculation."""
|
||||
nparams = self.vocab_size * self.dim + ... # Calculate params
|
||||
flops = 6 * nparams # Approximate: 6 * params for forward+backward
|
||||
return nparams, flops
|
||||
|
||||
def update_from_config(self, job_config) -> "YourModelArgs":
|
||||
"""Update args from training config."""
|
||||
# Override specific args from job_config if needed
|
||||
return self
|
||||
```
|
||||
|
||||
## Step 2: Define Model
|
||||
|
||||
Inherit from `ModelProtocol`:
|
||||
|
||||
```python
|
||||
# model/model.py
|
||||
import torch.nn as nn
|
||||
from torchtitan.protocols.model import ModelProtocol
|
||||
from .args import YourModelArgs
|
||||
|
||||
class YourModel(ModelProtocol):
|
||||
def __init__(self, args: YourModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
self.layers = nn.ModuleDict({
|
||||
str(i): TransformerBlock(args) for i in range(args.n_layers)
|
||||
})
|
||||
self.norm = RMSNorm(args.dim)
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
h = self.tok_embeddings(tokens)
|
||||
for layer in self.layers.values():
|
||||
h = layer(h)
|
||||
h = self.norm(h)
|
||||
return self.output(h)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights recursively."""
|
||||
for module in self.modules():
|
||||
if hasattr(module, 'init_weights') and module is not self:
|
||||
module.init_weights()
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=0.02)
|
||||
```
|
||||
|
||||
**Important guidelines**:
|
||||
- Write single-device model code (parallelism applied externally)
|
||||
- Use `nn.ModuleDict` for layers (preserves FQNs when deleting for PP)
|
||||
- Make input/output layers optional for PP compatibility
|
||||
- Define `init_weights()` recursively
|
||||
|
||||
## Step 3: Parallelize Function
|
||||
|
||||
```python
|
||||
# infra/parallelize.py
|
||||
from torch.distributed._composable.fsdp import fully_shard
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
|
||||
def parallelize_your_model(
|
||||
model: YourModel,
|
||||
world_mesh: DeviceMesh,
|
||||
parallel_dims: ParallelDims,
|
||||
job_config: JobConfig,
|
||||
):
|
||||
# Apply in this order: TP -> AC -> compile -> FSDP
|
||||
|
||||
# 1. Tensor Parallelism
|
||||
if parallel_dims.tp_enabled:
|
||||
apply_tp(model, world_mesh["tp"], job_config)
|
||||
|
||||
# 2. Activation Checkpointing
|
||||
if job_config.activation_checkpoint.mode == "full":
|
||||
apply_ac(model, job_config)
|
||||
|
||||
# 3. torch.compile
|
||||
if job_config.compile.enable:
|
||||
model = torch.compile(model)
|
||||
|
||||
# 4. FSDP
|
||||
if parallel_dims.dp_enabled:
|
||||
apply_fsdp(model, world_mesh["dp"], job_config)
|
||||
|
||||
return model
|
||||
```
|
||||
|
||||
## Step 4: Create TrainSpec
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec
|
||||
from .model.model import YourModel
|
||||
from .model.args import YourModelArgs
|
||||
from .infra.parallelize import parallelize_your_model
|
||||
|
||||
MODEL_CONFIGS = {
|
||||
"8B": YourModelArgs(dim=4096, n_layers=32, n_heads=32),
|
||||
"70B": YourModelArgs(dim=8192, n_layers=80, n_heads=64),
|
||||
}
|
||||
|
||||
def get_train_spec(flavor: str) -> TrainSpec:
|
||||
return TrainSpec(
|
||||
model_cls=YourModel,
|
||||
model_args=MODEL_CONFIGS[flavor],
|
||||
parallelize_fn=parallelize_your_model,
|
||||
pipeline_fn=None, # Or your_pipeline_fn for PP
|
||||
build_optimizer_fn=build_optimizer, # Reuse existing
|
||||
build_lr_scheduler_fn=build_lr_scheduler, # Reuse existing
|
||||
build_dataloader_fn=build_dataloader, # Reuse existing
|
||||
build_tokenizer_fn=build_tokenizer, # Reuse existing
|
||||
build_loss_fn=build_loss, # Reuse existing
|
||||
state_dict_adapter=None, # Or YourStateDictAdapter
|
||||
)
|
||||
|
||||
# Register so train.py can find it
|
||||
register_train_spec("your_model", get_train_spec)
|
||||
```
|
||||
|
||||
## Step 5: State Dict Adapter (Optional)
|
||||
|
||||
For HuggingFace checkpoint conversion:
|
||||
|
||||
```python
|
||||
# model/state_dict_adapter.py
|
||||
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter
|
||||
|
||||
class YourStateDictAdapter(BaseStateDictAdapter):
|
||||
def to_hf(self, state_dict: dict) -> dict:
|
||||
"""Convert torchtitan state dict to HF format."""
|
||||
hf_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
hf_key = self._convert_key_to_hf(key)
|
||||
hf_state_dict[hf_key] = value
|
||||
return hf_state_dict
|
||||
|
||||
def from_hf(self, state_dict: dict) -> dict:
|
||||
"""Convert HF state dict to torchtitan format."""
|
||||
tt_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
tt_key = self._convert_key_from_hf(key)
|
||||
tt_state_dict[tt_key] = value
|
||||
return tt_state_dict
|
||||
```
|
||||
|
||||
## Step 6: Training Config
|
||||
|
||||
```toml
|
||||
# train_configs/your_model_8b.toml
|
||||
[job]
|
||||
dump_folder = "./outputs"
|
||||
description = "Your Model 8B training"
|
||||
|
||||
[model]
|
||||
name = "your_model"
|
||||
flavor = "8B"
|
||||
|
||||
[optimizer]
|
||||
name = "AdamW"
|
||||
lr = 3e-4
|
||||
|
||||
[training]
|
||||
local_batch_size = 2
|
||||
seq_len = 8192
|
||||
steps = 1000
|
||||
dataset = "c4"
|
||||
|
||||
[parallelism]
|
||||
data_parallel_shard_degree = -1
|
||||
tensor_parallel_degree = 1
|
||||
```
|
||||
|
||||
## Step 7: Register Model
|
||||
|
||||
Add to `torchtitan/models/__init__.py`:
|
||||
|
||||
```python
|
||||
from .your_model import get_train_spec as get_your_model_train_spec
|
||||
|
||||
MODEL_REGISTRY["your_model"] = get_your_model_train_spec
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
### Numerics Test
|
||||
|
||||
Compare output with HuggingFace implementation:
|
||||
|
||||
```python
|
||||
def test_numerics():
|
||||
# Load same checkpoint into both implementations
|
||||
tt_model = YourModel(args).load_checkpoint(...)
|
||||
hf_model = HFYourModel.from_pretrained(...)
|
||||
|
||||
# Compare outputs
|
||||
input_ids = torch.randint(0, vocab_size, (1, 128))
|
||||
tt_output = tt_model(input_ids)
|
||||
hf_output = hf_model(input_ids).logits
|
||||
|
||||
torch.testing.assert_close(tt_output, hf_output, atol=1e-4, rtol=1e-4)
|
||||
```
|
||||
|
||||
### Loss Convergence
|
||||
|
||||
Compare loss curves with verified baseline (see `docs/converging.md`).
|
||||
|
||||
### Performance Benchmark
|
||||
|
||||
Add benchmark config to `benchmarks/` folder.
|
||||
|
||||
## Guiding Principles
|
||||
|
||||
1. **Readability over flexibility**: Don't over-abstract
|
||||
2. **Minimal model changes**: Parallelism applied externally
|
||||
3. **Clean, minimal codebase**: Reuse existing components where possible
|
||||
4. **Single-device semantics**: Model code should work on single GPU
|
||||
133
skills/mlops/torchtitan/references/float8.md
Normal file
133
skills/mlops/torchtitan/references/float8.md
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
# Float8 Training in TorchTitan
|
||||
|
||||
Float8 training provides substantial speedups for models where GEMMs are large enough that the FP8 tensorcore speedup outweighs dynamic quantization overhead.
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
- NVIDIA H100 or newer GPUs (FP8 Tensor Cores)
|
||||
- Blackwell GPUs for MXFP8 training
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
|
||||
```
|
||||
|
||||
## Usage: Tensorwise Scaling
|
||||
|
||||
Standard Float8 with tensorwise dynamic scaling:
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.enable_fsdp_float8_all_gather \
|
||||
--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
### Key Arguments
|
||||
|
||||
| Argument | Description |
|
||||
|----------|-------------|
|
||||
| `--model.converters="quantize.linear.float8"` | Swap `nn.Linear` with `Float8Linear` |
|
||||
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | Communicate in float8 to save bandwidth |
|
||||
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | Single all-reduce for all AMAX/scales |
|
||||
| `--compile.enable` | Required - fuses float8 scaling/casting kernels |
|
||||
|
||||
## Usage: Rowwise Scaling
|
||||
|
||||
Higher accuracy than tensorwise scaling:
|
||||
|
||||
```bash
|
||||
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
|
||||
--model.converters="quantize.linear.float8" \
|
||||
--quantize.linear.float8.recipe_name rowwise \
|
||||
--compile.enable
|
||||
```
|
||||
|
||||
## Filtering Layers
|
||||
|
||||
Not all layers benefit from Float8. Filter small layers:
|
||||
|
||||
```bash
|
||||
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
|
||||
```
|
||||
|
||||
### Auto-filtering
|
||||
|
||||
Automatically skip layers too small to benefit:
|
||||
|
||||
```bash
|
||||
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
|
||||
```
|
||||
|
||||
Thresholds based on H100 microbenchmarks where speedup > overhead.
|
||||
|
||||
## TOML Configuration
|
||||
|
||||
```toml
|
||||
[model]
|
||||
converters = ["quantize.linear.float8"]
|
||||
|
||||
[quantize.linear.float8]
|
||||
enable_fsdp_float8_all_gather = true
|
||||
precompute_float8_dynamic_scale_for_fsdp = true
|
||||
filter_fqns = ["output", "auto_filter_small_kn"]
|
||||
|
||||
[compile]
|
||||
enable = true
|
||||
components = ["model", "loss"]
|
||||
```
|
||||
|
||||
## How Float8 Works with Distributed Training
|
||||
|
||||
### Single Device
|
||||
|
||||
Cast input and weight to float8 inside forward before calling `torch._scaled_mm`:
|
||||
|
||||
```python
|
||||
# Float8 matmul requires scales
|
||||
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
|
||||
```
|
||||
|
||||
### FSDP + Float8
|
||||
|
||||
1. Cast sharded high-precision weights (1/N per rank) to float8
|
||||
2. Perform float8 all-gather (saves bandwidth vs bf16/fp32)
|
||||
3. Communicate `max(abs)` across ranks for scale computation
|
||||
4. At forward start, have unsharded float8 weights ready
|
||||
|
||||
**Net benefit**: Float8 all-gather + amax communication can beat bf16/fp32 all-gather, depending on world size and message size.
|
||||
|
||||
### TP + Float8
|
||||
|
||||
- **Input**: Cast sharded input to float8, all-gather in float8
|
||||
- **Weights**: Communicate `max(abs)` for sharded weights
|
||||
- **Matmul**: Float8 input (unsharded) x float8 weight (sharded) with global scales
|
||||
|
||||
## Scaling Strategies
|
||||
|
||||
| Strategy | Status | Description |
|
||||
|----------|--------|-------------|
|
||||
| Tensorwise dynamic | Stable | Single scale per tensor |
|
||||
| Rowwise dynamic | Alpha | Scale per row, higher accuracy |
|
||||
|
||||
## Performance Gains
|
||||
|
||||
From benchmarks on H100:
|
||||
|
||||
| Configuration | TPS/GPU | vs Baseline |
|
||||
|---------------|---------|-------------|
|
||||
| FSDP only | 5,762 | - |
|
||||
| FSDP + compile | 6,667 | +16% |
|
||||
| FSDP + compile + Float8 | 8,532 | +48% |
|
||||
|
||||
## Determining Float8 Benefit
|
||||
|
||||
Check [torchao microbenchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) for forward+backward pass speedups on "layer norm => linear => sigmoid" for different M,N,K sizes.
|
||||
|
||||
Rule of thumb: GEMMs with K,N > 4096 typically benefit from Float8.
|
||||
|
||||
## MXFP8 Training (Blackwell)
|
||||
|
||||
For NVIDIA Blackwell GPUs, TorchTitan supports MXFP8 (Microscaling FP8) for both dense and MoE models. See [docs/mxfp8.md](https://github.com/pytorch/torchtitan/blob/main/docs/mxfp8.md) for details.
|
||||
126
skills/mlops/torchtitan/references/fsdp.md
Normal file
126
skills/mlops/torchtitan/references/fsdp.md
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
# FSDP2 in TorchTitan
|
||||
|
||||
## Why FSDP2?
|
||||
|
||||
FSDP2 is a rewrite of PyTorch's Fully Sharded Data Parallel (FSDP) API, removing the `FlatParameter` abstraction for better composability and simpler implementation.
|
||||
|
||||
### Key improvements over FSDP1
|
||||
|
||||
- **DTensor-based sharding**: Sharded parameters are `DTensor`s on dim-0, enabling easy manipulation and communication-free sharded state dicts
|
||||
- **Better memory management**: Deterministic and lower GPU memory (7% reduction) by avoiding `recordStream`
|
||||
- **Simplified API**: Fewer arguments, no wrapper class
|
||||
|
||||
### Performance
|
||||
|
||||
On Llama-7B with 8x H100s, FSDP2 achieves higher MFU with 7% lower peak memory than FSDP1, matching the same loss curve.
|
||||
|
||||
## API Reference
|
||||
|
||||
```python
|
||||
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy
|
||||
|
||||
@contract(state_cls=FSDPState)
|
||||
def fully_shard(
|
||||
module: nn.Module,
|
||||
*,
|
||||
mesh: Optional[DeviceMesh] = None,
|
||||
reshard_after_forward: Union[bool, int] = True,
|
||||
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
|
||||
offload_policy: OffloadPolicy = OffloadPolicy(),
|
||||
) -> nn.Module:
|
||||
```
|
||||
|
||||
## Sharding Strategies (ZeRO Equivalents)
|
||||
|
||||
| FSDP2 Configuration | FSDP1 Equivalent | DeepSpeed |
|
||||
|---------------------|------------------|-----------|
|
||||
| 1D mesh + `reshard_after_forward=True` | FULL_SHARD | ZeRO-3 |
|
||||
| 1D mesh + `reshard_after_forward=False` | SHARD_GRAD_OP | ZeRO-2 |
|
||||
| 2D mesh + `reshard_after_forward=True` | HYBRID_SHARD | MiCS |
|
||||
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
|
||||
|
||||
## Meta-Device Initialization
|
||||
|
||||
FSDP2 supports materializing tensors onto GPU _after_ sharding:
|
||||
|
||||
```python
|
||||
# Initialize on meta device (no memory)
|
||||
with torch.device("meta"):
|
||||
model = Transformer()
|
||||
|
||||
# Apply FSDP2 sharding
|
||||
for module in model.modules():
|
||||
if isinstance(module, TransformerBlock):
|
||||
fully_shard(module)
|
||||
fully_shard(model)
|
||||
|
||||
# Parameters still on meta device
|
||||
for tensor in itertools.chain(model.parameters(), model.buffers()):
|
||||
assert tensor.device == torch.device("meta")
|
||||
|
||||
# Allocate sharded parameters on GPU
|
||||
model.to_empty(device="cuda")
|
||||
|
||||
# Initialize weights
|
||||
model.init_weights()
|
||||
```
|
||||
|
||||
## State Dict Differences
|
||||
|
||||
| Operation | FSDP1 | FSDP2 |
|
||||
|-----------|-------|-------|
|
||||
| `model.state_dict()` | Full state dict | Sharded state dict (no communication) |
|
||||
| `optim.state_dict()` | Local state dict | Sharded state dict (no communication) |
|
||||
| `summon_full_params()` | Supported | Use `DTensor` APIs like `full_tensor()` |
|
||||
| Gradient clipping | `FSDP.clip_grad_norm_()` | `nn.utils.clip_grad_norm_()` |
|
||||
|
||||
## Mixed Precision
|
||||
|
||||
```python
|
||||
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.float32,
|
||||
output_dtype=torch.bfloat16,
|
||||
cast_forward_inputs=True,
|
||||
)
|
||||
|
||||
fully_shard(model, mp_policy=mp_policy)
|
||||
```
|
||||
|
||||
## HSDP (Hybrid Sharded Data Parallel)
|
||||
|
||||
For 2D parallelism with replication + sharding:
|
||||
|
||||
```python
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
# Replicate across 4 groups, shard within 8 GPUs each
|
||||
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))
|
||||
|
||||
fully_shard(model, mesh=mesh)
|
||||
```
|
||||
|
||||
## Configuration in TorchTitan
|
||||
|
||||
```toml
|
||||
[parallelism]
|
||||
# FSDP sharding degree (-1 = auto, use all available GPUs)
|
||||
data_parallel_shard_degree = -1
|
||||
|
||||
# HSDP replication degree (1 = pure FSDP, >1 = HSDP)
|
||||
data_parallel_replicate_degree = 1
|
||||
```
|
||||
|
||||
## Removed Arguments from FSDP1
|
||||
|
||||
These FSDP1 arguments are no longer needed:
|
||||
|
||||
- `auto_wrap_policy`: Apply `fully_shard` directly to modules
|
||||
- `backward_prefetch`: Always uses BACKWARD_PRE
|
||||
- `param_init_fn`: Use meta-device initialization
|
||||
- `device_id`: Uses mesh's device automatically
|
||||
- `sync_module_states`: Not needed with DTensor
|
||||
- `limit_all_gathers`: New memory management doesn't need it
|
||||
- `use_orig_params`: Always true (no FlatParameter)
|
||||
458
skills/mlops/trl-fine-tuning/SKILL.md
Normal file
458
skills/mlops/trl-fine-tuning/SKILL.md
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
---
|
||||
name: fine-tuning-with-trl
|
||||
description: Fine-tune LLMs using reinforcement learning with TRL - SFT for instruction tuning, DPO for preference alignment, PPO/GRPO for reward optimization, and reward model training. Use when need RLHF, align model with preferences, or train from human feedback. Works with HuggingFace Transformers.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [trl, transformers, datasets, peft, accelerate, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, TRL, Reinforcement Learning, Fine-Tuning, SFT, DPO, PPO, GRPO, RLHF, Preference Alignment, HuggingFace]
|
||||
|
||||
---
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
## Quick start
|
||||
|
||||
TRL provides post-training methods for aligning language models with human preferences.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install trl transformers datasets peft accelerate
|
||||
```
|
||||
|
||||
**Supervised Fine-Tuning** (instruction tuning):
|
||||
```python
|
||||
from trl import SFTTrainer
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset, # Prompt-completion pairs
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
**DPO** (align with preferences):
|
||||
```python
|
||||
from trl import DPOTrainer, DPOConfig
|
||||
|
||||
config = DPOConfig(output_dir="model-dpo", beta=0.1)
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=preference_dataset, # chosen/rejected pairs
|
||||
processing_class=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Full RLHF pipeline (SFT → Reward Model → PPO)
|
||||
|
||||
Complete pipeline from base model to human-aligned model.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
RLHF Training:
|
||||
- [ ] Step 1: Supervised fine-tuning (SFT)
|
||||
- [ ] Step 2: Train reward model
|
||||
- [ ] Step 3: PPO reinforcement learning
|
||||
- [ ] Step 4: Evaluate aligned model
|
||||
```
|
||||
|
||||
**Step 1: Supervised fine-tuning**
|
||||
|
||||
Train base model on instruction-following data:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
|
||||
# Load instruction dataset
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
# Configure training
|
||||
training_args = SFTConfig(
|
||||
output_dir="Qwen2.5-0.5B-SFT",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=2e-5,
|
||||
logging_steps=10,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**Step 2: Train reward model**
|
||||
|
||||
Train model to predict human preferences:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
|
||||
# Load SFT model as base
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"Qwen2.5-0.5B-SFT",
|
||||
num_labels=1 # Single reward score
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-SFT")
|
||||
|
||||
# Load preference data (chosen/rejected pairs)
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
# Configure training
|
||||
training_args = RewardConfig(
|
||||
output_dir="Qwen2.5-0.5B-Reward",
|
||||
per_device_train_batch_size=2,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5
|
||||
)
|
||||
|
||||
# Train reward model
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**Step 3: PPO reinforcement learning**
|
||||
|
||||
Optimize policy using reward model:
|
||||
|
||||
```bash
|
||||
python -m trl.scripts.ppo \
|
||||
--model_name_or_path Qwen2.5-0.5B-SFT \
|
||||
--reward_model_path Qwen2.5-0.5B-Reward \
|
||||
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
||||
--output_dir Qwen2.5-0.5B-PPO \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--total_episodes 10000
|
||||
```
|
||||
|
||||
**Step 4: Evaluate**
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Load aligned model
|
||||
generator = pipeline("text-generation", model="Qwen2.5-0.5B-PPO")
|
||||
|
||||
# Test
|
||||
prompt = "Explain quantum computing to a 10-year-old"
|
||||
output = generator(prompt, max_length=200)[0]["generated_text"]
|
||||
print(output)
|
||||
```
|
||||
|
||||
### Workflow 2: Simple preference alignment with DPO
|
||||
|
||||
Align model with preferences without reward model.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
DPO Training:
|
||||
- [ ] Step 1: Prepare preference dataset
|
||||
- [ ] Step 2: Configure DPO
|
||||
- [ ] Step 3: Train with DPOTrainer
|
||||
- [ ] Step 4: Evaluate alignment
|
||||
```
|
||||
|
||||
**Step 1: Prepare preference dataset**
|
||||
|
||||
Dataset format:
|
||||
```json
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"chosen": "The capital of France is Paris.",
|
||||
"rejected": "I don't know."
|
||||
}
|
||||
```
|
||||
|
||||
Load dataset:
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
# Or load your own
|
||||
# dataset = load_dataset("json", data_files="preferences.json")
|
||||
```
|
||||
|
||||
**Step 2: Configure DPO**
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
config = DPOConfig(
|
||||
output_dir="Qwen2.5-0.5B-DPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=5e-7,
|
||||
beta=0.1, # KL penalty strength
|
||||
max_prompt_length=512,
|
||||
max_length=1024,
|
||||
logging_steps=10
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: Train with DPOTrainer**
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**CLI alternative**:
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name argilla/Capybara-Preferences \
|
||||
--output_dir Qwen2.5-0.5B-DPO \
|
||||
--per_device_train_batch_size 4 \
|
||||
--learning_rate 5e-7 \
|
||||
--beta 0.1
|
||||
```
|
||||
|
||||
### Workflow 3: Memory-efficient online RL with GRPO
|
||||
|
||||
Train with reinforcement learning using minimal memory.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
GRPO Training:
|
||||
- [ ] Step 1: Define reward function
|
||||
- [ ] Step 2: Configure GRPO
|
||||
- [ ] Step 3: Train with GRPOTrainer
|
||||
```
|
||||
|
||||
**Step 1: Define reward function**
|
||||
|
||||
```python
|
||||
def reward_function(completions, **kwargs):
|
||||
"""
|
||||
Compute rewards for completions.
|
||||
|
||||
Args:
|
||||
completions: List of generated texts
|
||||
|
||||
Returns:
|
||||
List of reward scores (floats)
|
||||
"""
|
||||
rewards = []
|
||||
for completion in completions:
|
||||
# Example: reward based on length and unique words
|
||||
score = len(completion.split()) # Favor longer responses
|
||||
score += len(set(completion.lower().split())) # Reward unique words
|
||||
rewards.append(score)
|
||||
return rewards
|
||||
```
|
||||
|
||||
Or use a reward model:
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
reward_model = pipeline("text-classification", model="reward-model-path")
|
||||
|
||||
def reward_from_model(completions, prompts, **kwargs):
|
||||
# Combine prompt + completion
|
||||
full_texts = [p + c for p, c in zip(prompts, completions)]
|
||||
# Get reward scores
|
||||
results = reward_model(full_texts)
|
||||
return [r["score"] for r in results]
|
||||
```
|
||||
|
||||
**Step 2: Configure GRPO**
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
config = GRPOConfig(
|
||||
output_dir="Qwen2-GRPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5,
|
||||
num_generations=4, # Generate 4 completions per prompt
|
||||
max_new_tokens=128
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: Train with GRPOTrainer**
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load prompt-only dataset
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_function, # Your reward function
|
||||
args=config,
|
||||
train_dataset=dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
**CLI**:
|
||||
```bash
|
||||
trl grpo \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--output_dir Qwen2-GRPO \
|
||||
--num_generations 4
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use TRL when:**
|
||||
- Need to align model with human preferences
|
||||
- Have preference data (chosen/rejected pairs)
|
||||
- Want to use reinforcement learning (PPO, GRPO)
|
||||
- Need reward model training
|
||||
- Doing RLHF (full pipeline)
|
||||
|
||||
**Method selection**:
|
||||
- **SFT**: Have prompt-completion pairs, want basic instruction following
|
||||
- **DPO**: Have preferences, want simple alignment (no reward model needed)
|
||||
- **PPO**: Have reward model, need maximum control over RL
|
||||
- **GRPO**: Memory-constrained, want online RL
|
||||
- **Reward Model**: Building RLHF pipeline, need to score generations
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **HuggingFace Trainer**: Basic fine-tuning without RL
|
||||
- **Axolotl**: YAML-based training configuration
|
||||
- **LitGPT**: Educational, minimal fine-tuning
|
||||
- **Unsloth**: Fast LoRA training
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: OOM during DPO training**
|
||||
|
||||
Reduce batch size and sequence length:
|
||||
```python
|
||||
config = DPOConfig(
|
||||
per_device_train_batch_size=1, # Reduce from 4
|
||||
max_length=512, # Reduce from 1024
|
||||
gradient_accumulation_steps=8 # Maintain effective batch
|
||||
)
|
||||
```
|
||||
|
||||
Or use gradient checkpointing:
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
|
||||
**Issue: Poor alignment quality**
|
||||
|
||||
Tune beta parameter:
|
||||
```python
|
||||
# Higher beta = more conservative (stays closer to reference)
|
||||
config = DPOConfig(beta=0.5) # Default 0.1
|
||||
|
||||
# Lower beta = more aggressive alignment
|
||||
config = DPOConfig(beta=0.01)
|
||||
```
|
||||
|
||||
**Issue: Reward model not learning**
|
||||
|
||||
Check loss type and learning rate:
|
||||
```python
|
||||
config = RewardConfig(
|
||||
learning_rate=1e-5, # Try different LR
|
||||
num_train_epochs=3 # Train longer
|
||||
)
|
||||
```
|
||||
|
||||
Ensure preference dataset has clear winners:
|
||||
```python
|
||||
# Verify dataset
|
||||
print(dataset[0])
|
||||
# Should have clear chosen > rejected
|
||||
```
|
||||
|
||||
**Issue: PPO training unstable**
|
||||
|
||||
Adjust KL coefficient:
|
||||
```python
|
||||
config = PPOConfig(
|
||||
kl_coef=0.1, # Increase from 0.05
|
||||
cliprange=0.1 # Reduce from 0.2
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**SFT training guide**: See [references/sft-training.md](references/sft-training.md) for dataset formats, chat templates, packing strategies, and multi-GPU training.
|
||||
|
||||
**DPO variants**: See [references/dpo-variants.md](references/dpo-variants.md) for IPO, cDPO, RPO, and other DPO loss functions with recommended hyperparameters.
|
||||
|
||||
**Reward modeling**: See [references/reward-modeling.md](references/reward-modeling.md) for outcome vs process rewards, Bradley-Terry loss, and reward model evaluation.
|
||||
|
||||
**Online RL methods**: See [references/online-rl.md](references/online-rl.md) for PPO, GRPO, RLOO, and OnlineDPO with detailed configurations.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA required)
|
||||
- **VRAM**: Depends on model and method
|
||||
- SFT 7B: 16GB (with LoRA)
|
||||
- DPO 7B: 24GB (stores reference model)
|
||||
- PPO 7B: 40GB (policy + reward model)
|
||||
- GRPO 7B: 24GB (more memory efficient)
|
||||
- **Multi-GPU**: Supported via `accelerate`
|
||||
- **Mixed precision**: BF16 recommended (A100/H100)
|
||||
|
||||
**Memory optimization**:
|
||||
- Use LoRA/QLoRA for all methods
|
||||
- Enable gradient checkpointing
|
||||
- Use smaller batch sizes with gradient accumulation
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://huggingface.co/docs/trl/
|
||||
- GitHub: https://github.com/huggingface/trl
|
||||
- Papers:
|
||||
- "Training language models to follow instructions with human feedback" (InstructGPT, 2022)
|
||||
- "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" (DPO, 2023)
|
||||
- "Group Relative Policy Optimization" (GRPO, 2024)
|
||||
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts
|
||||
|
||||
|
||||
|
||||
227
skills/mlops/trl-fine-tuning/references/dpo-variants.md
Normal file
227
skills/mlops/trl-fine-tuning/references/dpo-variants.md
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
# DPO Variants
|
||||
|
||||
Complete guide to Direct Preference Optimization loss variants in TRL.
|
||||
|
||||
## Overview
|
||||
|
||||
DPO optimizes models using preference data (chosen/rejected pairs). TRL supports 10+ loss variants for different scenarios.
|
||||
|
||||
## Loss Types
|
||||
|
||||
### 1. Sigmoid (Standard DPO)
|
||||
|
||||
**Formula**: `-log(sigmoid(β * logits))`
|
||||
|
||||
**When to use**: Default choice, general preference alignment
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="sigmoid",
|
||||
beta=0.1, # KL penalty
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=1e-6
|
||||
)
|
||||
```
|
||||
|
||||
### 2. IPO (Identity Policy Optimization)
|
||||
|
||||
**Formula**: `(logits - 1/(2β))²`
|
||||
|
||||
**When to use**: Better theoretical foundation, reduce overfitting
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="ipo",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=90,
|
||||
learning_rate=1e-2
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Hinge (SLiC)
|
||||
|
||||
**Formula**: `ReLU(1 - β * logits)`
|
||||
|
||||
**When to use**: Margin-based objective
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="hinge",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=512,
|
||||
learning_rate=1e-4
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Robust DPO
|
||||
|
||||
**Formula**: Sigmoid with label smoothing for noise robustness
|
||||
|
||||
**When to use**: Noisy preference labels
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="robust",
|
||||
beta=0.01,
|
||||
label_smoothing=0.1, # Noise probability
|
||||
per_device_train_batch_size=16,
|
||||
learning_rate=1e-3,
|
||||
max_prompt_length=128,
|
||||
max_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 5. BCO Pair (Binary Classification)
|
||||
|
||||
**Formula**: Train binary classifier (chosen=1, rejected=0)
|
||||
|
||||
**When to use**: Pairwise preference data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="bco_pair",
|
||||
beta=0.01,
|
||||
per_device_train_batch_size=128,
|
||||
learning_rate=5e-7,
|
||||
max_prompt_length=1536,
|
||||
max_completion_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 6. SPPO Hard
|
||||
|
||||
**Formula**: Push chosen→0.5, rejected→-0.5
|
||||
|
||||
**When to use**: Nash equilibrium, sparse data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="sppo_hard",
|
||||
beta=0.1
|
||||
)
|
||||
```
|
||||
|
||||
### 7. DiscoPOP
|
||||
|
||||
**Formula**: Log-Ratio Modulated Loss
|
||||
|
||||
**When to use**: Automated loss discovery
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="discopop",
|
||||
beta=0.05,
|
||||
discopop_tau=0.05,
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=5e-7
|
||||
)
|
||||
```
|
||||
|
||||
### 8. APO Zero
|
||||
|
||||
**Formula**: Increase chosen, decrease rejected likelihood
|
||||
|
||||
**When to use**: Model worse than winning outputs
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="apo_zero",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=2e-7,
|
||||
max_prompt_length=512,
|
||||
max_completion_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 9. APO Down
|
||||
|
||||
**Formula**: Decrease both, emphasize rejected reduction
|
||||
|
||||
**When to use**: Model better than winning outputs
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="apo_down",
|
||||
beta=0.1,
|
||||
# Same hyperparameters as apo_zero
|
||||
)
|
||||
```
|
||||
|
||||
### 10. AOT & AOT Pair
|
||||
|
||||
**Formula**: Distributional alignment via stochastic dominance
|
||||
|
||||
**When to use**:
|
||||
- `aot_pair`: Paired preference data
|
||||
- `aot`: Unpaired data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="aot_pair", # or "aot"
|
||||
beta=0.1,
|
||||
label_smoothing=0.0
|
||||
)
|
||||
```
|
||||
|
||||
## Multi-Loss Training
|
||||
|
||||
Combine multiple losses:
|
||||
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type=["sigmoid", "ipo"],
|
||||
loss_weights=[0.7, 0.3], # Weighted combination
|
||||
beta=0.1
|
||||
)
|
||||
```
|
||||
|
||||
## Key Parameters
|
||||
|
||||
### Beta (β)
|
||||
|
||||
Controls deviation from reference model:
|
||||
- **Higher** (0.5): More conservative, stays close to reference
|
||||
- **Lower** (0.01): More aggressive alignment
|
||||
- **Default**: 0.1
|
||||
|
||||
### Label Smoothing
|
||||
|
||||
For robust DPO:
|
||||
- **0.0**: No smoothing (default)
|
||||
- **0.1-0.3**: Moderate noise robustness
|
||||
- **0.5**: Maximum noise tolerance
|
||||
|
||||
### Max Lengths
|
||||
|
||||
- `max_prompt_length`: 128-1536
|
||||
- `max_completion_length`: 128-512
|
||||
- `max_length`: Total sequence (1024-2048)
|
||||
|
||||
## Comparison Table
|
||||
|
||||
| Loss | Speed | Stability | Best For |
|
||||
|------|-------|-----------|----------|
|
||||
| Sigmoid | Fast | Good | **General use** |
|
||||
| IPO | Fast | Better | Overfitting issues |
|
||||
| Hinge | Fast | Good | Margin objectives |
|
||||
| Robust | Fast | Best | Noisy data |
|
||||
| BCO | Medium | Good | Binary classification |
|
||||
| DiscoPOP | Fast | Good | New architectures |
|
||||
| APO | Fast | Good | Model quality matching |
|
||||
|
||||
## References
|
||||
|
||||
- DPO paper: https://arxiv.org/abs/2305.18290
|
||||
- IPO paper: https://arxiv.org/abs/2310.12036
|
||||
- TRL docs: https://huggingface.co/docs/trl/dpo_trainer
|
||||
82
skills/mlops/trl-fine-tuning/references/online-rl.md
Normal file
82
skills/mlops/trl-fine-tuning/references/online-rl.md
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
# Online RL Methods
|
||||
|
||||
Guide to online reinforcement learning with PPO, GRPO, RLOO, and OnlineDPO.
|
||||
|
||||
## Overview
|
||||
|
||||
Online RL generates completions during training and optimizes based on rewards.
|
||||
|
||||
## PPO (Proximal Policy Optimization)
|
||||
|
||||
Classic RL algorithm for LLM alignment.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
python -m trl.scripts.ppo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--reward_model_path reward-model \
|
||||
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
||||
--output_dir model-ppo \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--total_episodes 10000 \
|
||||
--num_ppo_epochs 4 \
|
||||
--kl_coef 0.05
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `kl_coef`: KL penalty (0.05-0.2)
|
||||
- `num_ppo_epochs`: Epochs per batch (2-4)
|
||||
- `cliprange`: PPO clip (0.1-0.3)
|
||||
- `vf_coef`: Value function coef (0.1)
|
||||
|
||||
## GRPO (Group Relative Policy Optimization)
|
||||
|
||||
Memory-efficient online RL.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Define reward function
|
||||
def reward_func(completions, **kwargs):
|
||||
return [len(set(c.split())) for c in completions]
|
||||
|
||||
config = GRPOConfig(
|
||||
output_dir="model-grpo",
|
||||
num_generations=4, # Completions per prompt
|
||||
max_new_tokens=128
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_func,
|
||||
args=config,
|
||||
train_dataset=load_dataset("trl-lib/tldr", split="train")
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `num_generations`: 2-8 completions
|
||||
- `max_new_tokens`: 64-256
|
||||
- Learning rate: 1e-5 to 1e-4
|
||||
|
||||
## Memory Comparison
|
||||
|
||||
| Method | Memory (7B) | Speed | Use Case |
|
||||
|--------|-------------|-------|----------|
|
||||
| PPO | 40GB | Medium | Maximum control |
|
||||
| GRPO | 24GB | Fast | **Memory-constrained** |
|
||||
| OnlineDPO | 28GB | Fast | No reward model |
|
||||
|
||||
## References
|
||||
|
||||
- PPO paper: https://arxiv.org/abs/1707.06347
|
||||
- GRPO paper: https://arxiv.org/abs/2402.03300
|
||||
- TRL docs: https://huggingface.co/docs/trl/
|
||||
122
skills/mlops/trl-fine-tuning/references/reward-modeling.md
Normal file
122
skills/mlops/trl-fine-tuning/references/reward-modeling.md
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
# Reward Modeling
|
||||
|
||||
Guide to training reward models with TRL for RLHF pipelines.
|
||||
|
||||
## Overview
|
||||
|
||||
Reward models score completions based on human preferences. Used in:
|
||||
- PPO training (RL feedback)
|
||||
- GRPO online RL
|
||||
- Completion ranking
|
||||
|
||||
## Basic Training
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model (num_labels=1 for single reward score)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
num_labels=1
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
|
||||
# Load preference dataset (chosen/rejected pairs)
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
# Configure
|
||||
config = RewardConfig(
|
||||
output_dir="Qwen2.5-Reward",
|
||||
per_device_train_batch_size=2,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
Required fields:
|
||||
```json
|
||||
{
|
||||
"prompt": "Question or instruction",
|
||||
"chosen": "Better response",
|
||||
"rejected": "Worse response"
|
||||
}
|
||||
```
|
||||
|
||||
## Bradley-Terry Loss
|
||||
|
||||
Default loss function:
|
||||
```
|
||||
loss = -log(sigmoid(reward_chosen - reward_rejected))
|
||||
```
|
||||
|
||||
Learns to score chosen > rejected.
|
||||
|
||||
## Using Reward Models
|
||||
|
||||
### Inference
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Load trained reward model
|
||||
reward_pipe = pipeline("text-classification", model="Qwen2.5-Reward")
|
||||
|
||||
# Score completions
|
||||
texts = ["Good answer", "Bad answer"]
|
||||
scores = reward_pipe(texts)
|
||||
print(scores) # Higher score = better
|
||||
```
|
||||
|
||||
### In PPO
|
||||
|
||||
```python
|
||||
from trl import PPOTrainer, PPOConfig
|
||||
|
||||
config = PPOConfig(
|
||||
reward_model_path="Qwen2.5-Reward" # Use trained reward model
|
||||
)
|
||||
|
||||
trainer = PPOTrainer(
|
||||
model=policy_model,
|
||||
config=config,
|
||||
# Reward model loaded automatically
|
||||
)
|
||||
```
|
||||
|
||||
## Hyperparameters
|
||||
|
||||
| Model Size | Learning Rate | Batch Size | Epochs |
|
||||
|------------|---------------|------------|--------|
|
||||
| <1B | 2e-5 | 4-8 | 1-2 |
|
||||
| 1-7B | 1e-5 | 2-4 | 1 |
|
||||
| 7-13B | 5e-6 | 1-2 | 1 |
|
||||
|
||||
## Evaluation
|
||||
|
||||
Check reward separation:
|
||||
```python
|
||||
# Chosen should score higher than rejected
|
||||
chosen_rewards = model(**chosen_inputs).logits
|
||||
rejected_rewards = model(**rejected_inputs).logits
|
||||
|
||||
accuracy = (chosen_rewards > rejected_rewards).float().mean()
|
||||
print(f"Accuracy: {accuracy:.2%}") # Target: >80%
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- InstructGPT paper: https://arxiv.org/abs/2203.02155
|
||||
- TRL docs: https://huggingface.co/docs/trl/reward_trainer
|
||||
168
skills/mlops/trl-fine-tuning/references/sft-training.md
Normal file
168
skills/mlops/trl-fine-tuning/references/sft-training.md
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
# SFT Training Guide
|
||||
|
||||
Complete guide to Supervised Fine-Tuning (SFT) with TRL for instruction tuning and task-specific fine-tuning.
|
||||
|
||||
## Overview
|
||||
|
||||
SFT trains models on input-output pairs to minimize cross-entropy loss. Use for:
|
||||
- Instruction following
|
||||
- Task-specific fine-tuning
|
||||
- Chatbot training
|
||||
- Domain adaptation
|
||||
|
||||
## Dataset Formats
|
||||
|
||||
### Format 1: Prompt-Completion
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"completion": "The capital of France is Paris."
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Format 2: Conversational (ChatML)
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
{"role": "assistant", "content": "Python is a programming language."}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Format 3: Text-only
|
||||
|
||||
```json
|
||||
[
|
||||
{"text": "User: Hello\nAssistant: Hi! How can I help?"}
|
||||
]
|
||||
```
|
||||
|
||||
## Basic Training
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
# Configure
|
||||
config = SFTConfig(
|
||||
output_dir="Qwen2.5-SFT",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=2e-5,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Chat Templates
|
||||
|
||||
Apply chat templates automatically:
|
||||
|
||||
```python
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset, # Messages format
|
||||
tokenizer=tokenizer
|
||||
# Chat template applied automatically
|
||||
)
|
||||
```
|
||||
|
||||
Or manually:
|
||||
```python
|
||||
def format_chat(example):
|
||||
messages = example["messages"]
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
return {"text": text}
|
||||
|
||||
dataset = dataset.map(format_chat)
|
||||
```
|
||||
|
||||
## Packing for Efficiency
|
||||
|
||||
Pack multiple sequences into one to maximize GPU utilization:
|
||||
|
||||
```python
|
||||
config = SFTConfig(
|
||||
packing=True, # Enable packing
|
||||
max_seq_length=2048,
|
||||
dataset_text_field="text"
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**: 2-3× faster training
|
||||
**Trade-off**: Slightly more complex batching
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
```bash
|
||||
accelerate launch --num_processes 4 train_sft.py
|
||||
```
|
||||
|
||||
Or with config:
|
||||
```python
|
||||
config = SFTConfig(
|
||||
output_dir="model-sft",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=1
|
||||
)
|
||||
```
|
||||
|
||||
## LoRA Fine-Tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
lora_dropout=0.05,
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
peft_config=lora_config # Add LoRA
|
||||
)
|
||||
```
|
||||
|
||||
## Hyperparameters
|
||||
|
||||
| Model Size | Learning Rate | Batch Size | Epochs |
|
||||
|------------|---------------|------------|--------|
|
||||
| <1B | 5e-5 | 8-16 | 1-3 |
|
||||
| 1-7B | 2e-5 | 4-8 | 1-2 |
|
||||
| 7-13B | 1e-5 | 2-4 | 1 |
|
||||
| 13B+ | 5e-6 | 1-2 | 1 |
|
||||
|
||||
## References
|
||||
|
||||
- TRL docs: https://huggingface.co/docs/trl/sft_trainer
|
||||
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts
|
||||
320
skills/mlops/whisper/SKILL.md
Normal file
320
skills/mlops/whisper/SKILL.md
Normal file
|
|
@ -0,0 +1,320 @@
|
|||
---
|
||||
name: whisper
|
||||
description: OpenAI's general-purpose speech recognition model. Supports 99 languages, transcription, translation to English, and language identification. Six model sizes from tiny (39M params) to large (1550M params). Use for speech-to-text, podcast transcription, or multilingual audio processing. Best for robust, multilingual ASR.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [openai-whisper, transformers, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Whisper, Speech Recognition, ASR, Multimodal, Multilingual, OpenAI, Speech-To-Text, Transcription, Translation, Audio Processing]
|
||||
|
||||
---
|
||||
|
||||
# Whisper - Robust Speech Recognition
|
||||
|
||||
OpenAI's multilingual speech recognition model.
|
||||
|
||||
## When to use Whisper
|
||||
|
||||
**Use when:**
|
||||
- Speech-to-text transcription (99 languages)
|
||||
- Podcast/video transcription
|
||||
- Meeting notes automation
|
||||
- Translation to English
|
||||
- Noisy audio transcription
|
||||
- Multilingual audio processing
|
||||
|
||||
**Metrics**:
|
||||
- **72,900+ GitHub stars**
|
||||
- 99 languages supported
|
||||
- Trained on 680,000 hours of audio
|
||||
- MIT License
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **AssemblyAI**: Managed API, speaker diarization
|
||||
- **Deepgram**: Real-time streaming ASR
|
||||
- **Google Speech-to-Text**: Cloud-based
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Requires Python 3.8-3.11
|
||||
pip install -U openai-whisper
|
||||
|
||||
# Requires ffmpeg
|
||||
# macOS: brew install ffmpeg
|
||||
# Ubuntu: sudo apt install ffmpeg
|
||||
# Windows: choco install ffmpeg
|
||||
```
|
||||
|
||||
### Basic transcription
|
||||
|
||||
```python
|
||||
import whisper
|
||||
|
||||
# Load model
|
||||
model = whisper.load_model("base")
|
||||
|
||||
# Transcribe
|
||||
result = model.transcribe("audio.mp3")
|
||||
|
||||
# Print text
|
||||
print(result["text"])
|
||||
|
||||
# Access segments
|
||||
for segment in result["segments"]:
|
||||
print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['text']}")
|
||||
```
|
||||
|
||||
## Model sizes
|
||||
|
||||
```python
|
||||
# Available models
|
||||
models = ["tiny", "base", "small", "medium", "large", "turbo"]
|
||||
|
||||
# Load specific model
|
||||
model = whisper.load_model("turbo") # Fastest, good quality
|
||||
```
|
||||
|
||||
| Model | Parameters | English-only | Multilingual | Speed | VRAM |
|
||||
|-------|------------|--------------|--------------|-------|------|
|
||||
| tiny | 39M | ✓ | ✓ | ~32x | ~1 GB |
|
||||
| base | 74M | ✓ | ✓ | ~16x | ~1 GB |
|
||||
| small | 244M | ✓ | ✓ | ~6x | ~2 GB |
|
||||
| medium | 769M | ✓ | ✓ | ~2x | ~5 GB |
|
||||
| large | 1550M | ✗ | ✓ | 1x | ~10 GB |
|
||||
| turbo | 809M | ✗ | ✓ | ~8x | ~6 GB |
|
||||
|
||||
**Recommendation**: Use `turbo` for best speed/quality, `base` for prototyping
|
||||
|
||||
## Transcription options
|
||||
|
||||
### Language specification
|
||||
|
||||
```python
|
||||
# Auto-detect language
|
||||
result = model.transcribe("audio.mp3")
|
||||
|
||||
# Specify language (faster)
|
||||
result = model.transcribe("audio.mp3", language="en")
|
||||
|
||||
# Supported: en, es, fr, de, it, pt, ru, ja, ko, zh, and 89 more
|
||||
```
|
||||
|
||||
### Task selection
|
||||
|
||||
```python
|
||||
# Transcription (default)
|
||||
result = model.transcribe("audio.mp3", task="transcribe")
|
||||
|
||||
# Translation to English
|
||||
result = model.transcribe("spanish.mp3", task="translate")
|
||||
# Input: Spanish audio → Output: English text
|
||||
```
|
||||
|
||||
### Initial prompt
|
||||
|
||||
```python
|
||||
# Improve accuracy with context
|
||||
result = model.transcribe(
|
||||
"audio.mp3",
|
||||
initial_prompt="This is a technical podcast about machine learning and AI."
|
||||
)
|
||||
|
||||
# Helps with:
|
||||
# - Technical terms
|
||||
# - Proper nouns
|
||||
# - Domain-specific vocabulary
|
||||
```
|
||||
|
||||
### Timestamps
|
||||
|
||||
```python
|
||||
# Word-level timestamps
|
||||
result = model.transcribe("audio.mp3", word_timestamps=True)
|
||||
|
||||
for segment in result["segments"]:
|
||||
for word in segment["words"]:
|
||||
print(f"{word['word']} ({word['start']:.2f}s - {word['end']:.2f}s)")
|
||||
```
|
||||
|
||||
### Temperature fallback
|
||||
|
||||
```python
|
||||
# Retry with different temperatures if confidence low
|
||||
result = model.transcribe(
|
||||
"audio.mp3",
|
||||
temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
|
||||
)
|
||||
```
|
||||
|
||||
## Command line usage
|
||||
|
||||
```bash
|
||||
# Basic transcription
|
||||
whisper audio.mp3
|
||||
|
||||
# Specify model
|
||||
whisper audio.mp3 --model turbo
|
||||
|
||||
# Output formats
|
||||
whisper audio.mp3 --output_format txt # Plain text
|
||||
whisper audio.mp3 --output_format srt # Subtitles
|
||||
whisper audio.mp3 --output_format vtt # WebVTT
|
||||
whisper audio.mp3 --output_format json # JSON with timestamps
|
||||
|
||||
# Language
|
||||
whisper audio.mp3 --language Spanish
|
||||
|
||||
# Translation
|
||||
whisper spanish.mp3 --task translate
|
||||
```
|
||||
|
||||
## Batch processing
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
audio_files = ["file1.mp3", "file2.mp3", "file3.mp3"]
|
||||
|
||||
for audio_file in audio_files:
|
||||
print(f"Transcribing {audio_file}...")
|
||||
result = model.transcribe(audio_file)
|
||||
|
||||
# Save to file
|
||||
output_file = audio_file.replace(".mp3", ".txt")
|
||||
with open(output_file, "w") as f:
|
||||
f.write(result["text"])
|
||||
```
|
||||
|
||||
## Real-time transcription
|
||||
|
||||
```python
|
||||
# For streaming audio, use faster-whisper
|
||||
# pip install faster-whisper
|
||||
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
model = WhisperModel("base", device="cuda", compute_type="float16")
|
||||
|
||||
# Transcribe with streaming
|
||||
segments, info = model.transcribe("audio.mp3", beam_size=5)
|
||||
|
||||
for segment in segments:
|
||||
print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
```python
|
||||
import whisper
|
||||
|
||||
# Automatically uses GPU if available
|
||||
model = whisper.load_model("turbo")
|
||||
|
||||
# Force CPU
|
||||
model = whisper.load_model("turbo", device="cpu")
|
||||
|
||||
# Force GPU
|
||||
model = whisper.load_model("turbo", device="cuda")
|
||||
|
||||
# 10-20× faster on GPU
|
||||
```
|
||||
|
||||
## Integration with other tools
|
||||
|
||||
### Subtitle generation
|
||||
|
||||
```bash
|
||||
# Generate SRT subtitles
|
||||
whisper video.mp4 --output_format srt --language English
|
||||
|
||||
# Output: video.srt
|
||||
```
|
||||
|
||||
### With LangChain
|
||||
|
||||
```python
|
||||
from langchain.document_loaders import WhisperTranscriptionLoader
|
||||
|
||||
loader = WhisperTranscriptionLoader(file_path="audio.mp3")
|
||||
docs = loader.load()
|
||||
|
||||
# Use transcription in RAG
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
vectorstore = Chroma.from_documents(docs, OpenAIEmbeddings())
|
||||
```
|
||||
|
||||
### Extract audio from video
|
||||
|
||||
```bash
|
||||
# Use ffmpeg to extract audio
|
||||
ffmpeg -i video.mp4 -vn -acodec pcm_s16le audio.wav
|
||||
|
||||
# Then transcribe
|
||||
whisper audio.wav
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use turbo model** - Best speed/quality for English
|
||||
2. **Specify language** - Faster than auto-detect
|
||||
3. **Add initial prompt** - Improves technical terms
|
||||
4. **Use GPU** - 10-20× faster
|
||||
5. **Batch process** - More efficient
|
||||
6. **Convert to WAV** - Better compatibility
|
||||
7. **Split long audio** - <30 min chunks
|
||||
8. **Check language support** - Quality varies by language
|
||||
9. **Use faster-whisper** - 4× faster than openai-whisper
|
||||
10. **Monitor VRAM** - Scale model size to hardware
|
||||
|
||||
## Performance
|
||||
|
||||
| Model | Real-time factor (CPU) | Real-time factor (GPU) |
|
||||
|-------|------------------------|------------------------|
|
||||
| tiny | ~0.32 | ~0.01 |
|
||||
| base | ~0.16 | ~0.01 |
|
||||
| turbo | ~0.08 | ~0.01 |
|
||||
| large | ~1.0 | ~0.05 |
|
||||
|
||||
*Real-time factor: 0.1 = 10× faster than real-time*
|
||||
|
||||
## Language support
|
||||
|
||||
Top-supported languages:
|
||||
- English (en)
|
||||
- Spanish (es)
|
||||
- French (fr)
|
||||
- German (de)
|
||||
- Italian (it)
|
||||
- Portuguese (pt)
|
||||
- Russian (ru)
|
||||
- Japanese (ja)
|
||||
- Korean (ko)
|
||||
- Chinese (zh)
|
||||
|
||||
Full list: 99 languages total
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Hallucinations** - May repeat or invent text
|
||||
2. **Long-form accuracy** - Degrades on >30 min audio
|
||||
3. **Speaker identification** - No diarization
|
||||
4. **Accents** - Quality varies
|
||||
5. **Background noise** - Can affect accuracy
|
||||
6. **Real-time latency** - Not suitable for live captioning
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/openai/whisper ⭐ 72,900+
|
||||
- **Paper**: https://arxiv.org/abs/2212.04356
|
||||
- **Model Card**: https://github.com/openai/whisper/blob/main/model-card.md
|
||||
- **Colab**: Available in repo
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
189
skills/mlops/whisper/references/languages.md
Normal file
189
skills/mlops/whisper/references/languages.md
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
# Whisper Language Support Guide
|
||||
|
||||
Complete guide to Whisper's multilingual capabilities.
|
||||
|
||||
## Supported languages (99 total)
|
||||
|
||||
### Top-tier support (WER < 10%)
|
||||
|
||||
- English (en)
|
||||
- Spanish (es)
|
||||
- French (fr)
|
||||
- German (de)
|
||||
- Italian (it)
|
||||
- Portuguese (pt)
|
||||
- Dutch (nl)
|
||||
- Polish (pl)
|
||||
- Russian (ru)
|
||||
- Japanese (ja)
|
||||
- Korean (ko)
|
||||
- Chinese (zh)
|
||||
|
||||
### Good support (WER 10-20%)
|
||||
|
||||
- Arabic (ar)
|
||||
- Turkish (tr)
|
||||
- Vietnamese (vi)
|
||||
- Swedish (sv)
|
||||
- Finnish (fi)
|
||||
- Czech (cs)
|
||||
- Romanian (ro)
|
||||
- Hungarian (hu)
|
||||
- Danish (da)
|
||||
- Norwegian (no)
|
||||
- Thai (th)
|
||||
- Hebrew (he)
|
||||
- Greek (el)
|
||||
- Indonesian (id)
|
||||
- Malay (ms)
|
||||
|
||||
### Full list (99 languages)
|
||||
|
||||
Afrikaans, Albanian, Amharic, Arabic, Armenian, Assamese, Azerbaijani, Bashkir, Basque, Belarusian, Bengali, Bosnian, Breton, Bulgarian, Burmese, Cantonese, Catalan, Chinese, Croatian, Czech, Danish, Dutch, English, Estonian, Faroese, Finnish, French, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hungarian, Icelandic, Indonesian, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Lao, Latin, Latvian, Lingala, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Moldavian, Mongolian, Myanmar, Nepali, Norwegian, Nynorsk, Occitan, Pashto, Persian, Polish, Portuguese, Punjabi, Pushto, Romanian, Russian, Sanskrit, Serbian, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tagalog, Tajik, Tamil, Tatar, Telugu, Thai, Tibetan, Turkish, Turkmen, Ukrainian, Urdu, Uzbek, Vietnamese, Welsh, Yiddish, Yoruba
|
||||
|
||||
## Usage examples
|
||||
|
||||
### Auto-detect language
|
||||
|
||||
```python
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model("turbo")
|
||||
|
||||
# Auto-detect language
|
||||
result = model.transcribe("audio.mp3")
|
||||
|
||||
print(f"Detected language: {result['language']}")
|
||||
print(f"Text: {result['text']}")
|
||||
```
|
||||
|
||||
### Specify language (faster)
|
||||
|
||||
```python
|
||||
# Specify language for faster transcription
|
||||
result = model.transcribe("audio.mp3", language="es") # Spanish
|
||||
result = model.transcribe("audio.mp3", language="fr") # French
|
||||
result = model.transcribe("audio.mp3", language="ja") # Japanese
|
||||
```
|
||||
|
||||
### Translation to English
|
||||
|
||||
```python
|
||||
# Translate any language to English
|
||||
result = model.transcribe(
|
||||
"spanish_audio.mp3",
|
||||
task="translate" # Translates to English
|
||||
)
|
||||
|
||||
print(f"Original language: {result['language']}")
|
||||
print(f"English translation: {result['text']}")
|
||||
```
|
||||
|
||||
## Language-specific tips
|
||||
|
||||
### Chinese
|
||||
|
||||
```python
|
||||
# Chinese works well with larger models
|
||||
model = whisper.load_model("large")
|
||||
|
||||
result = model.transcribe(
|
||||
"chinese_audio.mp3",
|
||||
language="zh",
|
||||
initial_prompt="这是一段关于技术的讨论" # Context helps
|
||||
)
|
||||
```
|
||||
|
||||
### Japanese
|
||||
|
||||
```python
|
||||
# Japanese benefits from initial prompt
|
||||
result = model.transcribe(
|
||||
"japanese_audio.mp3",
|
||||
language="ja",
|
||||
initial_prompt="これは技術的な会議の録音です"
|
||||
)
|
||||
```
|
||||
|
||||
### Arabic
|
||||
|
||||
```python
|
||||
# Arabic: Use large model for best results
|
||||
model = whisper.load_model("large")
|
||||
|
||||
result = model.transcribe(
|
||||
"arabic_audio.mp3",
|
||||
language="ar"
|
||||
)
|
||||
```
|
||||
|
||||
## Model size recommendations
|
||||
|
||||
| Language Tier | Recommended Model | WER |
|
||||
|---------------|-------------------|-----|
|
||||
| Top-tier (en, es, fr, de) | base/turbo | < 10% |
|
||||
| Good (ar, tr, vi) | medium/large | 10-20% |
|
||||
| Lower-resource | large | 20-30% |
|
||||
|
||||
## Performance by language
|
||||
|
||||
### English
|
||||
|
||||
- **tiny**: WER ~15%
|
||||
- **base**: WER ~8%
|
||||
- **small**: WER ~5%
|
||||
- **medium**: WER ~4%
|
||||
- **large**: WER ~3%
|
||||
- **turbo**: WER ~3.5%
|
||||
|
||||
### Spanish
|
||||
|
||||
- **tiny**: WER ~20%
|
||||
- **base**: WER ~12%
|
||||
- **medium**: WER ~6%
|
||||
- **large**: WER ~4%
|
||||
|
||||
### Chinese
|
||||
|
||||
- **small**: WER ~15%
|
||||
- **medium**: WER ~8%
|
||||
- **large**: WER ~5%
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use English-only models** - Better for small models (tiny/base)
|
||||
2. **Specify language** - Faster than auto-detect
|
||||
3. **Add initial prompt** - Improves accuracy for technical terms
|
||||
4. **Use larger models** - For low-resource languages
|
||||
5. **Test on sample** - Quality varies by accent/dialect
|
||||
6. **Consider audio quality** - Clear audio = better results
|
||||
7. **Check language codes** - Use ISO 639-1 codes (2 letters)
|
||||
|
||||
## Language detection
|
||||
|
||||
```python
|
||||
# Detect language only (no transcription)
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model("base")
|
||||
|
||||
# Load audio
|
||||
audio = whisper.load_audio("audio.mp3")
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
|
||||
# Make log-Mel spectrogram
|
||||
mel = whisper.log_mel_spectrogram(audio).to(model.device)
|
||||
|
||||
# Detect language
|
||||
_, probs = model.detect_language(mel)
|
||||
detected_language = max(probs, key=probs.get)
|
||||
|
||||
print(f"Detected language: {detected_language}")
|
||||
print(f"Confidence: {probs[detected_language]:.2%}")
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: https://arxiv.org/abs/2212.04356
|
||||
- **GitHub**: https://github.com/openai/whisper
|
||||
- **Model Card**: https://github.com/openai/whisper/blob/main/model-card.md
|
||||
Loading…
Add table
Add a link
Reference in a new issue