mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-27 06:11:40 +00:00
chore(skills): move heavy training skills + outlines to optional-skills (#22912)
These skills require heavy GPU/CUDA stacks or are niche enough that they shouldn't be active by default. Moved to optional-skills/ where users opt-in via `hermes skills install official/...`. Moved: - mlops/training/axolotl - mlops/training/trl-fine-tuning - mlops/training/unsloth - mlops/inference/outlines Counts: 91 -> 87 built-in, 72 -> 76 optional. Auto-regenerated docs (per-skill pages + catalogs) reflect the move.
This commit is contained in:
parent
4375b82cd9
commit
ded194eb6a
27 changed files with 18 additions and 18 deletions
656
optional-skills/mlops/inference/outlines/SKILL.md
Normal file
656
optional-skills/mlops/inference/outlines/SKILL.md
Normal file
|
|
@ -0,0 +1,656 @@
|
|||
---
|
||||
name: outlines
|
||||
description: "Outlines: structured JSON/regex/Pydantic LLM generation."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [outlines, transformers, vllm, pydantic]
|
||||
platforms: [linux, macos, windows]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Prompt Engineering, Outlines, Structured Generation, JSON Schema, Pydantic, Local Models, Grammar-Based Generation, vLLM, Transformers, Type Safety]
|
||||
|
||||
---
|
||||
|
||||
# Outlines: Structured Text Generation
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use Outlines when you need to:
|
||||
- **Guarantee valid JSON/XML/code** structure during generation
|
||||
- **Use Pydantic models** for type-safe outputs
|
||||
- **Support local models** (Transformers, llama.cpp, vLLM)
|
||||
- **Maximize inference speed** with zero-overhead structured generation
|
||||
- **Generate against JSON schemas** automatically
|
||||
- **Control token sampling** at the grammar level
|
||||
|
||||
**GitHub Stars**: 8,000+ | **From**: dottxt.ai (formerly .txt)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Base installation
|
||||
pip install outlines
|
||||
|
||||
# With specific backends
|
||||
pip install outlines transformers # Hugging Face models
|
||||
pip install outlines llama-cpp-python # llama.cpp
|
||||
pip install outlines vllm # vLLM for high-throughput
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Classification
|
||||
|
||||
```python
|
||||
import outlines
|
||||
from typing import Literal
|
||||
|
||||
# Load model
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Generate with type constraint
|
||||
prompt = "Sentiment of 'This product is amazing!': "
|
||||
generator = outlines.generate.choice(model, ["positive", "negative", "neutral"])
|
||||
sentiment = generator(prompt)
|
||||
|
||||
print(sentiment) # "positive" (guaranteed one of these)
|
||||
```
|
||||
|
||||
### With Pydantic Models
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
import outlines
|
||||
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Generate structured output
|
||||
prompt = "Extract user: John Doe, 30 years old, john@example.com"
|
||||
generator = outlines.generate.json(model, User)
|
||||
user = generator(prompt)
|
||||
|
||||
print(user.name) # "John Doe"
|
||||
print(user.age) # 30
|
||||
print(user.email) # "john@example.com"
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Constrained Token Sampling
|
||||
|
||||
Outlines uses Finite State Machines (FSM) to constrain token generation at the logit level.
|
||||
|
||||
**How it works:**
|
||||
1. Convert schema (JSON/Pydantic/regex) to context-free grammar (CFG)
|
||||
2. Transform CFG into Finite State Machine (FSM)
|
||||
3. Filter invalid tokens at each step during generation
|
||||
4. Fast-forward when only one valid token exists
|
||||
|
||||
**Benefits:**
|
||||
- **Zero overhead**: Filtering happens at token level
|
||||
- **Speed improvement**: Fast-forward through deterministic paths
|
||||
- **Guaranteed validity**: Invalid outputs impossible
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Pydantic model -> JSON schema -> CFG -> FSM
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Behind the scenes:
|
||||
# 1. Person -> JSON schema
|
||||
# 2. JSON schema -> CFG
|
||||
# 3. CFG -> FSM
|
||||
# 4. FSM filters tokens during generation
|
||||
|
||||
generator = outlines.generate.json(model, Person)
|
||||
result = generator("Generate person: Alice, 25")
|
||||
```
|
||||
|
||||
### 2. Structured Generators
|
||||
|
||||
Outlines provides specialized generators for different output types.
|
||||
|
||||
#### Choice Generator
|
||||
|
||||
```python
|
||||
# Multiple choice selection
|
||||
generator = outlines.generate.choice(
|
||||
model,
|
||||
["positive", "negative", "neutral"]
|
||||
)
|
||||
|
||||
sentiment = generator("Review: This is great!")
|
||||
# Result: One of the three choices
|
||||
```
|
||||
|
||||
#### JSON Generator
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
in_stock: bool
|
||||
|
||||
# Generate valid JSON matching schema
|
||||
generator = outlines.generate.json(model, Product)
|
||||
product = generator("Extract: iPhone 15, $999, available")
|
||||
|
||||
# Guaranteed valid Product instance
|
||||
print(type(product)) # <class '__main__.Product'>
|
||||
```
|
||||
|
||||
#### Regex Generator
|
||||
|
||||
```python
|
||||
# Generate text matching regex
|
||||
generator = outlines.generate.regex(
|
||||
model,
|
||||
r"[0-9]{3}-[0-9]{3}-[0-9]{4}" # Phone number pattern
|
||||
)
|
||||
|
||||
phone = generator("Generate phone number:")
|
||||
# Result: "555-123-4567" (guaranteed to match pattern)
|
||||
```
|
||||
|
||||
#### Integer/Float Generators
|
||||
|
||||
```python
|
||||
# Generate specific numeric types
|
||||
int_generator = outlines.generate.integer(model)
|
||||
age = int_generator("Person's age:") # Guaranteed integer
|
||||
|
||||
float_generator = outlines.generate.float(model)
|
||||
price = float_generator("Product price:") # Guaranteed float
|
||||
```
|
||||
|
||||
### 3. Model Backends
|
||||
|
||||
Outlines supports multiple local and API-based backends.
|
||||
|
||||
#### Transformers (Hugging Face)
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Load from Hugging Face
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda" # Or "cpu"
|
||||
)
|
||||
|
||||
# Use with any generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
#### llama.cpp
|
||||
|
||||
```python
|
||||
# Load GGUF model
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b-instruct.Q4_K_M.gguf",
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
#### vLLM (High Throughput)
|
||||
|
||||
```python
|
||||
# For production deployments
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
tensor_parallel_size=2 # Multi-GPU
|
||||
)
|
||||
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
#### OpenAI (Limited Support)
|
||||
|
||||
```python
|
||||
# Basic OpenAI support
|
||||
model = outlines.models.openai(
|
||||
"gpt-4o-mini",
|
||||
api_key="your-api-key"
|
||||
)
|
||||
|
||||
# Note: Some features limited with API models
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
### 4. Pydantic Integration
|
||||
|
||||
Outlines has first-class Pydantic support with automatic schema translation.
|
||||
|
||||
#### Basic Models
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Article(BaseModel):
|
||||
title: str = Field(description="Article title")
|
||||
author: str = Field(description="Author name")
|
||||
word_count: int = Field(description="Number of words", gt=0)
|
||||
tags: list[str] = Field(description="List of tags")
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, Article)
|
||||
|
||||
article = generator("Generate article about AI")
|
||||
print(article.title)
|
||||
print(article.word_count) # Guaranteed > 0
|
||||
```
|
||||
|
||||
#### Nested Models
|
||||
|
||||
```python
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
country: str
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
address: Address # Nested model
|
||||
|
||||
generator = outlines.generate.json(model, Person)
|
||||
person = generator("Generate person in New York")
|
||||
|
||||
print(person.address.city) # "New York"
|
||||
```
|
||||
|
||||
#### Enums and Literals
|
||||
|
||||
```python
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
class Status(str, Enum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
|
||||
class Application(BaseModel):
|
||||
applicant: str
|
||||
status: Status # Must be one of enum values
|
||||
priority: Literal["low", "medium", "high"] # Must be one of literals
|
||||
|
||||
generator = outlines.generate.json(model, Application)
|
||||
app = generator("Generate application")
|
||||
|
||||
print(app.status) # Status.PENDING (or APPROVED/REJECTED)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: Data Extraction
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
import outlines
|
||||
|
||||
class CompanyInfo(BaseModel):
|
||||
name: str
|
||||
founded_year: int
|
||||
industry: str
|
||||
employees: int
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, CompanyInfo)
|
||||
|
||||
text = """
|
||||
Apple Inc. was founded in 1976 in the technology industry.
|
||||
The company employs approximately 164,000 people worldwide.
|
||||
"""
|
||||
|
||||
prompt = f"Extract company information:\n{text}\n\nCompany:"
|
||||
company = generator(prompt)
|
||||
|
||||
print(f"Name: {company.name}")
|
||||
print(f"Founded: {company.founded_year}")
|
||||
print(f"Industry: {company.industry}")
|
||||
print(f"Employees: {company.employees}")
|
||||
```
|
||||
|
||||
### Pattern 2: Classification
|
||||
|
||||
```python
|
||||
from typing import Literal
|
||||
import outlines
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Binary classification
|
||||
generator = outlines.generate.choice(model, ["spam", "not_spam"])
|
||||
result = generator("Email: Buy now! 50% off!")
|
||||
|
||||
# Multi-class classification
|
||||
categories = ["technology", "business", "sports", "entertainment"]
|
||||
category_gen = outlines.generate.choice(model, categories)
|
||||
category = category_gen("Article: Apple announces new iPhone...")
|
||||
|
||||
# With confidence
|
||||
class Classification(BaseModel):
|
||||
label: Literal["positive", "negative", "neutral"]
|
||||
confidence: float
|
||||
|
||||
classifier = outlines.generate.json(model, Classification)
|
||||
result = classifier("Review: This product is okay, nothing special")
|
||||
```
|
||||
|
||||
### Pattern 3: Structured Forms
|
||||
|
||||
```python
|
||||
class UserProfile(BaseModel):
|
||||
full_name: str
|
||||
age: int
|
||||
email: str
|
||||
phone: str
|
||||
country: str
|
||||
interests: list[str]
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, UserProfile)
|
||||
|
||||
prompt = """
|
||||
Extract user profile from:
|
||||
Name: Alice Johnson
|
||||
Age: 28
|
||||
Email: alice@example.com
|
||||
Phone: 555-0123
|
||||
Country: USA
|
||||
Interests: hiking, photography, cooking
|
||||
"""
|
||||
|
||||
profile = generator(prompt)
|
||||
print(profile.full_name)
|
||||
print(profile.interests) # ["hiking", "photography", "cooking"]
|
||||
```
|
||||
|
||||
### Pattern 4: Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
class Entity(BaseModel):
|
||||
name: str
|
||||
type: Literal["PERSON", "ORGANIZATION", "LOCATION"]
|
||||
|
||||
class DocumentEntities(BaseModel):
|
||||
entities: list[Entity]
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, DocumentEntities)
|
||||
|
||||
text = "Tim Cook met with Satya Nadella at Microsoft headquarters in Redmond."
|
||||
prompt = f"Extract entities from: {text}"
|
||||
|
||||
result = generator(prompt)
|
||||
for entity in result.entities:
|
||||
print(f"{entity.name} ({entity.type})")
|
||||
```
|
||||
|
||||
### Pattern 5: Code Generation
|
||||
|
||||
```python
|
||||
class PythonFunction(BaseModel):
|
||||
function_name: str
|
||||
parameters: list[str]
|
||||
docstring: str
|
||||
body: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, PythonFunction)
|
||||
|
||||
prompt = "Generate a Python function to calculate factorial"
|
||||
func = generator(prompt)
|
||||
|
||||
print(f"def {func.function_name}({', '.join(func.parameters)}):")
|
||||
print(f' """{func.docstring}"""')
|
||||
print(f" {func.body}")
|
||||
```
|
||||
|
||||
### Pattern 6: Batch Processing
|
||||
|
||||
```python
|
||||
def batch_extract(texts: list[str], schema: type[BaseModel]):
|
||||
"""Extract structured data from multiple texts."""
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
for text in texts:
|
||||
result = generator(f"Extract from: {text}")
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
texts = [
|
||||
"John is 30 years old",
|
||||
"Alice is 25 years old",
|
||||
"Bob is 40 years old"
|
||||
]
|
||||
|
||||
people = batch_extract(texts, Person)
|
||||
for person in people:
|
||||
print(f"{person.name}: {person.age}")
|
||||
```
|
||||
|
||||
## Backend Configuration
|
||||
|
||||
### Transformers
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Basic usage
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# GPU configuration
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda",
|
||||
model_kwargs={"torch_dtype": "float16"}
|
||||
)
|
||||
|
||||
# Popular models
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
model = outlines.models.transformers("Qwen/Qwen2.5-7B-Instruct")
|
||||
```
|
||||
|
||||
### llama.cpp
|
||||
|
||||
```python
|
||||
# Load GGUF model
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b.Q4_K_M.gguf",
|
||||
n_ctx=4096, # Context window
|
||||
n_gpu_layers=35, # GPU layers
|
||||
n_threads=8 # CPU threads
|
||||
)
|
||||
|
||||
# Full GPU offload
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.gguf",
|
||||
n_gpu_layers=-1 # All layers on GPU
|
||||
)
|
||||
```
|
||||
|
||||
### vLLM (Production)
|
||||
|
||||
```python
|
||||
# Single GPU
|
||||
model = outlines.models.vllm("meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
# Multi-GPU
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
tensor_parallel_size=4 # 4 GPUs
|
||||
)
|
||||
|
||||
# With quantization
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
quantization="awq" # Or "gptq"
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Specific Types
|
||||
|
||||
```python
|
||||
# ✅ Good: Specific types
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float # Not str
|
||||
quantity: int # Not str
|
||||
in_stock: bool # Not str
|
||||
|
||||
# ❌ Bad: Everything as string
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: str # Should be float
|
||||
quantity: str # Should be int
|
||||
```
|
||||
|
||||
### 2. Add Constraints
|
||||
|
||||
```python
|
||||
from pydantic import Field
|
||||
|
||||
# ✅ Good: With constraints
|
||||
class User(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=100)
|
||||
age: int = Field(ge=0, le=120)
|
||||
email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")
|
||||
|
||||
# ❌ Bad: No constraints
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
```
|
||||
|
||||
### 3. Use Enums for Categories
|
||||
|
||||
```python
|
||||
# ✅ Good: Enum for fixed set
|
||||
class Priority(str, Enum):
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
|
||||
class Task(BaseModel):
|
||||
title: str
|
||||
priority: Priority
|
||||
|
||||
# ❌ Bad: Free-form string
|
||||
class Task(BaseModel):
|
||||
title: str
|
||||
priority: str # Can be anything
|
||||
```
|
||||
|
||||
### 4. Provide Context in Prompts
|
||||
|
||||
```python
|
||||
# ✅ Good: Clear context
|
||||
prompt = """
|
||||
Extract product information from the following text.
|
||||
Text: iPhone 15 Pro costs $999 and is currently in stock.
|
||||
Product:
|
||||
"""
|
||||
|
||||
# ❌ Bad: Minimal context
|
||||
prompt = "iPhone 15 Pro costs $999 and is currently in stock."
|
||||
```
|
||||
|
||||
### 5. Handle Optional Fields
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
# ✅ Good: Optional fields for incomplete data
|
||||
class Article(BaseModel):
|
||||
title: str # Required
|
||||
author: Optional[str] = None # Optional
|
||||
date: Optional[str] = None # Optional
|
||||
tags: list[str] = [] # Default empty list
|
||||
|
||||
# Can succeed even if author/date missing
|
||||
```
|
||||
|
||||
## Comparison to Alternatives
|
||||
|
||||
| Feature | Outlines | Instructor | Guidance | LMQL |
|
||||
|---------|----------|------------|----------|------|
|
||||
| Pydantic Support | ✅ Native | ✅ Native | ❌ No | ❌ No |
|
||||
| JSON Schema | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
|
||||
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
|
||||
| Local Models | ✅ Full | ⚠️ Limited | ✅ Full | ✅ Full |
|
||||
| API Models | ⚠️ Limited | ✅ Full | ✅ Full | ✅ Full |
|
||||
| Zero Overhead | ✅ Yes | ❌ No | ⚠️ Partial | ✅ Yes |
|
||||
| Automatic Retrying | ❌ No | ✅ Yes | ❌ No | ❌ No |
|
||||
| Learning Curve | Low | Low | Low | High |
|
||||
|
||||
**When to choose Outlines:**
|
||||
- Using local models (Transformers, llama.cpp, vLLM)
|
||||
- Need maximum inference speed
|
||||
- Want Pydantic model support
|
||||
- Require zero-overhead structured generation
|
||||
- Control token sampling process
|
||||
|
||||
**When to choose alternatives:**
|
||||
- Instructor: Need API models with automatic retrying
|
||||
- Guidance: Need token healing and complex workflows
|
||||
- LMQL: Prefer declarative query syntax
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
**Speed:**
|
||||
- **Zero overhead**: Structured generation as fast as unconstrained
|
||||
- **Fast-forward optimization**: Skips deterministic tokens
|
||||
- **1.2-2x faster** than post-generation validation approaches
|
||||
|
||||
**Memory:**
|
||||
- FSM compiled once per schema (cached)
|
||||
- Minimal runtime overhead
|
||||
- Efficient with vLLM for high throughput
|
||||
|
||||
**Accuracy:**
|
||||
- **100% valid outputs** (guaranteed by FSM)
|
||||
- No retry loops needed
|
||||
- Deterministic token filtering
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://outlines-dev.github.io/outlines
|
||||
- **GitHub**: https://github.com/outlines-dev/outlines (8k+ stars)
|
||||
- **Discord**: https://discord.gg/R9DSu34mGd
|
||||
- **Blog**: https://blog.dottxt.co
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/json_generation.md` - Comprehensive JSON and Pydantic patterns
|
||||
- `references/backends.md` - Backend-specific configuration
|
||||
- `references/examples.md` - Production-ready examples
|
||||
|
||||
|
||||
615
optional-skills/mlops/inference/outlines/references/backends.md
Normal file
615
optional-skills/mlops/inference/outlines/references/backends.md
Normal file
|
|
@ -0,0 +1,615 @@
|
|||
# Backend Configuration Guide
|
||||
|
||||
Complete guide to configuring Outlines with different model backends.
|
||||
|
||||
## Table of Contents
|
||||
- Local Models (Transformers, llama.cpp, vLLM)
|
||||
- API Models (OpenAI)
|
||||
- Performance Comparison
|
||||
- Configuration Examples
|
||||
- Production Deployment
|
||||
|
||||
## Transformers (Hugging Face)
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Load model from Hugging Face
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Use with generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
result = generator("Your prompt")
|
||||
```
|
||||
|
||||
### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use CUDA GPU
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Use specific GPU
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda:0" # GPU 0
|
||||
)
|
||||
|
||||
# Use CPU
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Use Apple Silicon MPS
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="mps"
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```python
|
||||
# FP16 for faster inference
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"torch_dtype": "float16"
|
||||
}
|
||||
)
|
||||
|
||||
# 8-bit quantization (less memory)
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"load_in_8bit": True,
|
||||
"device_map": "auto"
|
||||
}
|
||||
)
|
||||
|
||||
# 4-bit quantization (even less memory)
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"load_in_4bit": True,
|
||||
"device_map": "auto",
|
||||
"bnb_4bit_compute_dtype": "float16"
|
||||
}
|
||||
)
|
||||
|
||||
# Multi-GPU
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"device_map": "auto", # Automatic GPU distribution
|
||||
"max_memory": {0: "40GB", 1: "40GB"} # Per-GPU limits
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Popular Models
|
||||
|
||||
```python
|
||||
# Phi-4 (Microsoft)
|
||||
model = outlines.models.transformers("microsoft/Phi-4-mini-instruct")
|
||||
model = outlines.models.transformers("microsoft/Phi-3-medium-4k-instruct")
|
||||
|
||||
# Llama 3.1 (Meta)
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-70B-Instruct")
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-405B-Instruct")
|
||||
|
||||
# Mistral (Mistral AI)
|
||||
model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
model = outlines.models.transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
model = outlines.models.transformers("mistralai/Mixtral-8x22B-Instruct-v0.1")
|
||||
|
||||
# Qwen (Alibaba)
|
||||
model = outlines.models.transformers("Qwen/Qwen2.5-7B-Instruct")
|
||||
model = outlines.models.transformers("Qwen/Qwen2.5-14B-Instruct")
|
||||
model = outlines.models.transformers("Qwen/Qwen2.5-72B-Instruct")
|
||||
|
||||
# Gemma (Google)
|
||||
model = outlines.models.transformers("google/gemma-2-9b-it")
|
||||
model = outlines.models.transformers("google/gemma-2-27b-it")
|
||||
|
||||
# Llava (Vision)
|
||||
model = outlines.models.transformers("llava-hf/llava-v1.6-mistral-7b-hf")
|
||||
```
|
||||
|
||||
### Custom Model Loading
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import outlines
|
||||
|
||||
# Load model manually
|
||||
tokenizer = AutoTokenizer.from_pretrained("your-model")
|
||||
model_hf = AutoModelForCausalLM.from_pretrained(
|
||||
"your-model",
|
||||
device_map="auto",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use with Outlines
|
||||
model = outlines.models.transformers(
|
||||
model=model_hf,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
```
|
||||
|
||||
## llama.cpp
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Load GGUF model
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b-instruct.Q4_K_M.gguf",
|
||||
n_ctx=4096 # Context window
|
||||
)
|
||||
|
||||
# Use with generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
### GPU Configuration
|
||||
|
||||
```python
|
||||
# CPU only
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_threads=8 # Use 8 CPU threads
|
||||
)
|
||||
|
||||
# GPU offload (partial)
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35, # Offload 35 layers to GPU
|
||||
n_threads=4 # CPU threads for remaining layers
|
||||
)
|
||||
|
||||
# Full GPU offload
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.gguf",
|
||||
n_ctx=8192,
|
||||
n_gpu_layers=-1 # All layers on GPU
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```python
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b.Q4_K_M.gguf",
|
||||
n_ctx=8192, # Context window (tokens)
|
||||
n_gpu_layers=35, # GPU layers
|
||||
n_threads=8, # CPU threads
|
||||
n_batch=512, # Batch size for prompt processing
|
||||
use_mmap=True, # Memory-map model file (faster loading)
|
||||
use_mlock=False, # Lock model in RAM (prevents swapping)
|
||||
seed=42, # Random seed for reproducibility
|
||||
verbose=False # Suppress verbose output
|
||||
)
|
||||
```
|
||||
|
||||
### Quantization Formats
|
||||
|
||||
```python
|
||||
# Q4_K_M (4-bit, recommended for most cases)
|
||||
# - Size: ~4.5GB for 7B model
|
||||
# - Quality: Good
|
||||
# - Speed: Fast
|
||||
model = outlines.models.llamacpp("./models/model.Q4_K_M.gguf")
|
||||
|
||||
# Q5_K_M (5-bit, better quality)
|
||||
# - Size: ~5.5GB for 7B model
|
||||
# - Quality: Very good
|
||||
# - Speed: Slightly slower than Q4
|
||||
model = outlines.models.llamacpp("./models/model.Q5_K_M.gguf")
|
||||
|
||||
# Q6_K (6-bit, high quality)
|
||||
# - Size: ~6.5GB for 7B model
|
||||
# - Quality: Excellent
|
||||
# - Speed: Slower than Q5
|
||||
model = outlines.models.llamacpp("./models/model.Q6_K.gguf")
|
||||
|
||||
# Q8_0 (8-bit, near-original quality)
|
||||
# - Size: ~8GB for 7B model
|
||||
# - Quality: Near FP16
|
||||
# - Speed: Slower than Q6
|
||||
model = outlines.models.llamacpp("./models/model.Q8_0.gguf")
|
||||
|
||||
# F16 (16-bit float, original quality)
|
||||
# - Size: ~14GB for 7B model
|
||||
# - Quality: Original
|
||||
# - Speed: Slowest
|
||||
model = outlines.models.llamacpp("./models/model.F16.gguf")
|
||||
```
|
||||
|
||||
### Popular GGUF Models
|
||||
|
||||
```python
|
||||
# Llama 3.1
|
||||
model = outlines.models.llamacpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
|
||||
model = outlines.models.llamacpp("llama-3.1-70b-instruct.Q4_K_M.gguf")
|
||||
|
||||
# Mistral
|
||||
model = outlines.models.llamacpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
|
||||
|
||||
# Phi-4
|
||||
model = outlines.models.llamacpp("phi-4-mini-instruct.Q4_K_M.gguf")
|
||||
|
||||
# Qwen
|
||||
model = outlines.models.llamacpp("qwen2.5-7b-instruct.Q4_K_M.gguf")
|
||||
```
|
||||
|
||||
### Apple Silicon Optimization
|
||||
|
||||
```python
|
||||
# Optimized for M1/M2/M3 Macs
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b.Q4_K_M.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=-1, # Use Metal GPU acceleration
|
||||
use_mmap=True, # Efficient memory mapping
|
||||
n_threads=8 # Use performance cores
|
||||
)
|
||||
```
|
||||
|
||||
## vLLM (Production)
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Load model with vLLM
|
||||
model = outlines.models.vllm("meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
# Use with generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
### Single GPU
|
||||
|
||||
```python
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
gpu_memory_utilization=0.9, # Use 90% of GPU memory
|
||||
max_model_len=4096 # Max sequence length
|
||||
)
|
||||
```
|
||||
|
||||
### Multi-GPU
|
||||
|
||||
```python
|
||||
# Tensor parallelism (split model across GPUs)
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
tensor_parallel_size=4, # Use 4 GPUs
|
||||
gpu_memory_utilization=0.9
|
||||
)
|
||||
|
||||
# Pipeline parallelism (rare, for very large models)
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-405B-Instruct",
|
||||
pipeline_parallel_size=8, # 8-GPU pipeline
|
||||
tensor_parallel_size=4 # 4-GPU tensor split
|
||||
# Total: 32 GPUs
|
||||
)
|
||||
```
|
||||
|
||||
### Quantization
|
||||
|
||||
```python
|
||||
# AWQ quantization (4-bit)
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
quantization="awq",
|
||||
dtype="float16"
|
||||
)
|
||||
|
||||
# GPTQ quantization (4-bit)
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
quantization="gptq"
|
||||
)
|
||||
|
||||
# SqueezeLLM quantization
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
quantization="squeezellm"
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```python
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=256, # Max concurrent sequences
|
||||
max_num_batched_tokens=8192, # Max tokens per batch
|
||||
dtype="float16",
|
||||
trust_remote_code=True,
|
||||
enforce_eager=False, # Use CUDA graphs (faster)
|
||||
swap_space=4 # CPU swap space (GB)
|
||||
)
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
# vLLM optimized for high-throughput batch processing
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_num_seqs=128 # Process 128 sequences in parallel
|
||||
)
|
||||
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
|
||||
# Process many prompts efficiently
|
||||
prompts = ["prompt1", "prompt2", ..., "prompt100"]
|
||||
results = [generator(p) for p in prompts]
|
||||
# vLLM automatically batches and optimizes
|
||||
```
|
||||
|
||||
## OpenAI (Limited Support)
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Basic OpenAI support
|
||||
model = outlines.models.openai("gpt-4o-mini", api_key="your-api-key")
|
||||
|
||||
# Use with generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
result = generator("Your prompt")
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```python
|
||||
model = outlines.models.openai(
|
||||
"gpt-4o-mini",
|
||||
api_key="your-api-key", # Or set OPENAI_API_KEY env var
|
||||
max_tokens=2048,
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
### Available Models
|
||||
|
||||
```python
|
||||
# GPT-4o (latest)
|
||||
model = outlines.models.openai("gpt-4o")
|
||||
|
||||
# GPT-4o Mini (cost-effective)
|
||||
model = outlines.models.openai("gpt-4o-mini")
|
||||
|
||||
# GPT-4 Turbo
|
||||
model = outlines.models.openai("gpt-4-turbo")
|
||||
|
||||
# GPT-3.5 Turbo
|
||||
model = outlines.models.openai("gpt-3.5-turbo")
|
||||
```
|
||||
|
||||
**Note**: OpenAI support is limited compared to local models. Some advanced features may not work.
|
||||
|
||||
## Backend Comparison
|
||||
|
||||
### Feature Matrix
|
||||
|
||||
| Feature | Transformers | llama.cpp | vLLM | OpenAI |
|
||||
|---------|-------------|-----------|------|--------|
|
||||
| Structured Generation | ✅ Full | ✅ Full | ✅ Full | ⚠️ Limited |
|
||||
| FSM Optimization | ✅ Yes | ✅ Yes | ✅ Yes | ❌ No |
|
||||
| GPU Support | ✅ Yes | ✅ Yes | ✅ Yes | N/A |
|
||||
| Multi-GPU | ✅ Yes | ✅ Yes | ✅ Yes | N/A |
|
||||
| Quantization | ✅ Yes | ✅ Yes | ✅ Yes | N/A |
|
||||
| High Throughput | ⚠️ Medium | ⚠️ Medium | ✅ Excellent | ⚠️ API-limited |
|
||||
| Setup Difficulty | Easy | Medium | Medium | Easy |
|
||||
| Cost | Hardware | Hardware | Hardware | API usage |
|
||||
|
||||
### Performance Characteristics
|
||||
|
||||
**Transformers:**
|
||||
- **Latency**: 50-200ms (single request, GPU)
|
||||
- **Throughput**: 10-50 tokens/sec (depends on hardware)
|
||||
- **Memory**: 2-4GB per 1B parameters (FP16)
|
||||
- **Best for**: Development, small-scale deployment, flexibility
|
||||
|
||||
**llama.cpp:**
|
||||
- **Latency**: 30-150ms (single request)
|
||||
- **Throughput**: 20-150 tokens/sec (depends on quantization)
|
||||
- **Memory**: 0.5-2GB per 1B parameters (Q4-Q8)
|
||||
- **Best for**: CPU inference, Apple Silicon, edge deployment, low memory
|
||||
|
||||
**vLLM:**
|
||||
- **Latency**: 30-100ms (single request)
|
||||
- **Throughput**: 100-1000+ tokens/sec (batch processing)
|
||||
- **Memory**: 2-4GB per 1B parameters (FP16)
|
||||
- **Best for**: Production, high-throughput, batch processing, serving
|
||||
|
||||
**OpenAI:**
|
||||
- **Latency**: 200-500ms (API call)
|
||||
- **Throughput**: API rate limits
|
||||
- **Memory**: N/A (cloud-based)
|
||||
- **Best for**: Quick prototyping, no infrastructure
|
||||
|
||||
### Memory Requirements
|
||||
|
||||
**7B Model:**
|
||||
- FP16: ~14GB
|
||||
- 8-bit: ~7GB
|
||||
- 4-bit: ~4GB
|
||||
- Q4_K_M (GGUF): ~4.5GB
|
||||
|
||||
**13B Model:**
|
||||
- FP16: ~26GB
|
||||
- 8-bit: ~13GB
|
||||
- 4-bit: ~7GB
|
||||
- Q4_K_M (GGUF): ~8GB
|
||||
|
||||
**70B Model:**
|
||||
- FP16: ~140GB (multi-GPU)
|
||||
- 8-bit: ~70GB (multi-GPU)
|
||||
- 4-bit: ~35GB (single A100/H100)
|
||||
- Q4_K_M (GGUF): ~40GB
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Transformers Optimization
|
||||
|
||||
```python
|
||||
# Use FP16
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={"torch_dtype": "float16"}
|
||||
)
|
||||
|
||||
# Use flash attention (2-4x faster)
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"torch_dtype": "float16",
|
||||
"use_flash_attention_2": True
|
||||
}
|
||||
)
|
||||
|
||||
# Use 8-bit quantization (2x less memory)
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"load_in_8bit": True,
|
||||
"device_map": "auto"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### llama.cpp Optimization
|
||||
|
||||
```python
|
||||
# Maximize GPU usage
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1, # All layers on GPU
|
||||
n_ctx=8192,
|
||||
n_batch=512 # Larger batch = faster
|
||||
)
|
||||
|
||||
# Optimize for CPU (Apple Silicon)
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.Q4_K_M.gguf",
|
||||
n_ctx=4096,
|
||||
n_threads=8, # Use all performance cores
|
||||
use_mmap=True
|
||||
)
|
||||
```
|
||||
|
||||
### vLLM Optimization
|
||||
|
||||
```python
|
||||
# High throughput
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
gpu_memory_utilization=0.95, # Use 95% of GPU
|
||||
max_num_seqs=256, # High concurrency
|
||||
enforce_eager=False # Use CUDA graphs
|
||||
)
|
||||
|
||||
# Multi-GPU
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
tensor_parallel_size=4, # 4 GPUs
|
||||
gpu_memory_utilization=0.9
|
||||
)
|
||||
```
|
||||
|
||||
## Production Deployment
|
||||
|
||||
### Docker with vLLM
|
||||
|
||||
```dockerfile
|
||||
FROM vllm/vllm-openai:latest
|
||||
|
||||
# Install outlines
|
||||
RUN pip install outlines
|
||||
|
||||
# Copy your code
|
||||
COPY app.py /app/
|
||||
|
||||
# Run
|
||||
CMD ["python", "/app/app.py"]
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Transformers cache
|
||||
export HF_HOME="/path/to/cache"
|
||||
export TRANSFORMERS_CACHE="/path/to/cache"
|
||||
|
||||
# GPU selection
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
# OpenAI API key
|
||||
export OPENAI_API_KEY="sk-..."
|
||||
|
||||
# Disable tokenizers parallelism warning
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
```
|
||||
|
||||
### Model Serving
|
||||
|
||||
```python
|
||||
# Simple HTTP server with vLLM
|
||||
import outlines
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model once at startup
|
||||
model = outlines.models.vllm("meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
generator = outlines.generate.json(model, User)
|
||||
|
||||
@app.post("/extract")
|
||||
def extract(text: str):
|
||||
result = generator(f"Extract user from: {text}")
|
||||
return result.model_dump()
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Transformers**: https://huggingface.co/docs/transformers
|
||||
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
|
||||
- **vLLM**: https://docs.vllm.ai
|
||||
- **Outlines**: https://github.com/outlines-dev/outlines
|
||||
773
optional-skills/mlops/inference/outlines/references/examples.md
Normal file
773
optional-skills/mlops/inference/outlines/references/examples.md
Normal file
|
|
@ -0,0 +1,773 @@
|
|||
# Production-Ready Examples
|
||||
|
||||
Real-world examples of using Outlines for structured generation in production systems.
|
||||
|
||||
## Table of Contents
|
||||
- Data Extraction
|
||||
- Classification Systems
|
||||
- Form Processing
|
||||
- Multi-Entity Extraction
|
||||
- Code Generation
|
||||
- Batch Processing
|
||||
- Production Patterns
|
||||
|
||||
## Data Extraction
|
||||
|
||||
### Basic Information Extraction
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
import outlines
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(ge=0, le=120)
|
||||
occupation: str
|
||||
email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")
|
||||
location: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, PersonInfo)
|
||||
|
||||
text = """
|
||||
Dr. Sarah Johnson is a 42-year-old research scientist at MIT.
|
||||
She can be reached at sarah.j@mit.edu and currently lives in Cambridge, MA.
|
||||
"""
|
||||
|
||||
prompt = f"Extract person information from:\n{text}\n\nPerson:"
|
||||
person = generator(prompt)
|
||||
|
||||
print(f"Name: {person.name}")
|
||||
print(f"Age: {person.age}")
|
||||
print(f"Occupation: {person.occupation}")
|
||||
print(f"Email: {person.email}")
|
||||
print(f"Location: {person.location}")
|
||||
```
|
||||
|
||||
### Company Information
|
||||
|
||||
```python
|
||||
class CompanyInfo(BaseModel):
|
||||
name: str
|
||||
founded_year: int = Field(ge=1800, le=2025)
|
||||
industry: str
|
||||
headquarters: str
|
||||
employees: int = Field(gt=0)
|
||||
revenue: Optional[str] = None
|
||||
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
generator = outlines.generate.json(model, CompanyInfo)
|
||||
|
||||
text = """
|
||||
Tesla, Inc. was founded in 2003 and operates primarily in the automotive
|
||||
and energy industries. The company is headquartered in Austin, Texas,
|
||||
and employs approximately 140,000 people worldwide.
|
||||
"""
|
||||
|
||||
company = generator(f"Extract company information:\n{text}\n\nCompany:")
|
||||
|
||||
print(f"Company: {company.name}")
|
||||
print(f"Founded: {company.founded_year}")
|
||||
print(f"Industry: {company.industry}")
|
||||
print(f"HQ: {company.headquarters}")
|
||||
print(f"Employees: {company.employees:,}")
|
||||
```
|
||||
|
||||
### Product Specifications
|
||||
|
||||
```python
|
||||
class ProductSpec(BaseModel):
|
||||
name: str
|
||||
brand: str
|
||||
price: float = Field(gt=0)
|
||||
dimensions: str
|
||||
weight: str
|
||||
features: list[str]
|
||||
rating: Optional[float] = Field(None, ge=0, le=5)
|
||||
|
||||
generator = outlines.generate.json(model, ProductSpec)
|
||||
|
||||
text = """
|
||||
The Apple iPhone 15 Pro is priced at $999. It measures 146.6 x 70.6 x 8.25 mm
|
||||
and weighs 187 grams. Key features include the A17 Pro chip, titanium design,
|
||||
action button, and USB-C port. It has an average customer rating of 4.5 stars.
|
||||
"""
|
||||
|
||||
product = generator(f"Extract product specifications:\n{text}\n\nProduct:")
|
||||
|
||||
print(f"Product: {product.brand} {product.name}")
|
||||
print(f"Price: ${product.price}")
|
||||
print(f"Features: {', '.join(product.features)}")
|
||||
```
|
||||
|
||||
## Classification Systems
|
||||
|
||||
### Sentiment Analysis
|
||||
|
||||
```python
|
||||
from typing import Literal
|
||||
from enum import Enum
|
||||
|
||||
class Sentiment(str, Enum):
|
||||
VERY_POSITIVE = "very_positive"
|
||||
POSITIVE = "positive"
|
||||
NEUTRAL = "neutral"
|
||||
NEGATIVE = "negative"
|
||||
VERY_NEGATIVE = "very_negative"
|
||||
|
||||
class SentimentAnalysis(BaseModel):
|
||||
text: str
|
||||
sentiment: Sentiment
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
aspects: list[str] # What aspects were mentioned
|
||||
reasoning: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, SentimentAnalysis)
|
||||
|
||||
review = """
|
||||
This product completely exceeded my expectations! The build quality is
|
||||
outstanding, and customer service was incredibly helpful. My only minor
|
||||
complaint is the packaging could be better.
|
||||
"""
|
||||
|
||||
result = generator(f"Analyze sentiment:\n{review}\n\nAnalysis:")
|
||||
|
||||
print(f"Sentiment: {result.sentiment.value}")
|
||||
print(f"Confidence: {result.confidence:.2%}")
|
||||
print(f"Aspects: {', '.join(result.aspects)}")
|
||||
print(f"Reasoning: {result.reasoning}")
|
||||
```
|
||||
|
||||
### Content Classification
|
||||
|
||||
```python
|
||||
class Category(str, Enum):
|
||||
TECHNOLOGY = "technology"
|
||||
BUSINESS = "business"
|
||||
SCIENCE = "science"
|
||||
POLITICS = "politics"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
SPORTS = "sports"
|
||||
HEALTH = "health"
|
||||
|
||||
class ArticleClassification(BaseModel):
|
||||
primary_category: Category
|
||||
secondary_categories: list[Category]
|
||||
keywords: list[str] = Field(min_items=3, max_items=10)
|
||||
target_audience: Literal["general", "expert", "beginner"]
|
||||
reading_level: Literal["elementary", "intermediate", "advanced"]
|
||||
|
||||
generator = outlines.generate.json(model, ArticleClassification)
|
||||
|
||||
article = """
|
||||
Apple announced groundbreaking advancements in its AI capabilities with the
|
||||
release of iOS 18. The new features leverage machine learning to significantly
|
||||
improve battery life and overall device performance. Industry analysts predict
|
||||
this will strengthen Apple's position in the competitive smartphone market.
|
||||
"""
|
||||
|
||||
classification = generator(f"Classify article:\n{article}\n\nClassification:")
|
||||
|
||||
print(f"Primary: {classification.primary_category.value}")
|
||||
print(f"Secondary: {[c.value for c in classification.secondary_categories]}")
|
||||
print(f"Keywords: {classification.keywords}")
|
||||
print(f"Audience: {classification.target_audience}")
|
||||
```
|
||||
|
||||
### Intent Recognition
|
||||
|
||||
```python
|
||||
class Intent(str, Enum):
|
||||
QUESTION = "question"
|
||||
COMPLAINT = "complaint"
|
||||
REQUEST = "request"
|
||||
FEEDBACK = "feedback"
|
||||
CANCEL = "cancel"
|
||||
UPGRADE = "upgrade"
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
original_message: str
|
||||
intent: Intent
|
||||
urgency: Literal["low", "medium", "high", "critical"]
|
||||
department: Literal["support", "sales", "billing", "technical"]
|
||||
sentiment: Literal["positive", "neutral", "negative"]
|
||||
action_required: bool
|
||||
summary: str
|
||||
|
||||
generator = outlines.generate.json(model, UserMessage)
|
||||
|
||||
message = """
|
||||
I've been charged twice for my subscription this month! This is the third
|
||||
time this has happened. I need someone to fix this immediately and refund
|
||||
the extra charge. Very disappointed with this service.
|
||||
"""
|
||||
|
||||
result = generator(f"Analyze message:\n{message}\n\nAnalysis:")
|
||||
|
||||
print(f"Intent: {result.intent.value}")
|
||||
print(f"Urgency: {result.urgency}")
|
||||
print(f"Route to: {result.department}")
|
||||
print(f"Action required: {result.action_required}")
|
||||
print(f"Summary: {result.summary}")
|
||||
```
|
||||
|
||||
## Form Processing
|
||||
|
||||
### Job Application
|
||||
|
||||
```python
|
||||
class Education(BaseModel):
|
||||
degree: str
|
||||
field: str
|
||||
institution: str
|
||||
year: int
|
||||
|
||||
class Experience(BaseModel):
|
||||
title: str
|
||||
company: str
|
||||
duration: str
|
||||
responsibilities: list[str]
|
||||
|
||||
class JobApplication(BaseModel):
|
||||
full_name: str
|
||||
email: str
|
||||
phone: str
|
||||
education: list[Education]
|
||||
experience: list[Experience]
|
||||
skills: list[str]
|
||||
availability: str
|
||||
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
generator = outlines.generate.json(model, JobApplication)
|
||||
|
||||
resume_text = """
|
||||
John Smith
|
||||
Email: john.smith@email.com | Phone: 555-0123
|
||||
|
||||
EDUCATION
|
||||
- BS in Computer Science, MIT, 2018
|
||||
- MS in Artificial Intelligence, Stanford, 2020
|
||||
|
||||
EXPERIENCE
|
||||
Software Engineer, Google (2020-2023)
|
||||
- Developed ML pipelines for search ranking
|
||||
- Led team of 5 engineers
|
||||
- Improved search quality by 15%
|
||||
|
||||
SKILLS: Python, Machine Learning, TensorFlow, System Design
|
||||
|
||||
AVAILABILITY: Immediate
|
||||
"""
|
||||
|
||||
application = generator(f"Extract job application:\n{resume_text}\n\nApplication:")
|
||||
|
||||
print(f"Applicant: {application.full_name}")
|
||||
print(f"Email: {application.email}")
|
||||
print(f"Education: {len(application.education)} degrees")
|
||||
for edu in application.education:
|
||||
print(f" - {edu.degree} in {edu.field}, {edu.institution} ({edu.year})")
|
||||
print(f"Experience: {len(application.experience)} positions")
|
||||
```
|
||||
|
||||
### Invoice Processing
|
||||
|
||||
```python
|
||||
class InvoiceItem(BaseModel):
|
||||
description: str
|
||||
quantity: int = Field(gt=0)
|
||||
unit_price: float = Field(gt=0)
|
||||
total: float = Field(gt=0)
|
||||
|
||||
class Invoice(BaseModel):
|
||||
invoice_number: str
|
||||
date: str = Field(pattern=r"\d{4}-\d{2}-\d{2}")
|
||||
vendor: str
|
||||
customer: str
|
||||
items: list[InvoiceItem]
|
||||
subtotal: float = Field(gt=0)
|
||||
tax: float = Field(ge=0)
|
||||
total: float = Field(gt=0)
|
||||
|
||||
generator = outlines.generate.json(model, Invoice)
|
||||
|
||||
invoice_text = """
|
||||
INVOICE #INV-2024-001
|
||||
Date: 2024-01-15
|
||||
|
||||
From: Acme Corp
|
||||
To: Smith & Co
|
||||
|
||||
Items:
|
||||
- Widget A: 10 units @ $50.00 = $500.00
|
||||
- Widget B: 5 units @ $75.00 = $375.00
|
||||
- Service Fee: 1 @ $100.00 = $100.00
|
||||
|
||||
Subtotal: $975.00
|
||||
Tax (8%): $78.00
|
||||
TOTAL: $1,053.00
|
||||
"""
|
||||
|
||||
invoice = generator(f"Extract invoice:\n{invoice_text}\n\nInvoice:")
|
||||
|
||||
print(f"Invoice: {invoice.invoice_number}")
|
||||
print(f"From: {invoice.vendor} → To: {invoice.customer}")
|
||||
print(f"Items: {len(invoice.items)}")
|
||||
for item in invoice.items:
|
||||
print(f" - {item.description}: {item.quantity} × ${item.unit_price} = ${item.total}")
|
||||
print(f"Total: ${invoice.total}")
|
||||
```
|
||||
|
||||
### Survey Responses
|
||||
|
||||
```python
|
||||
class SurveyResponse(BaseModel):
|
||||
respondent_id: str
|
||||
completion_date: str
|
||||
satisfaction: Literal[1, 2, 3, 4, 5]
|
||||
would_recommend: bool
|
||||
favorite_features: list[str]
|
||||
improvement_areas: list[str]
|
||||
additional_comments: Optional[str] = None
|
||||
|
||||
generator = outlines.generate.json(model, SurveyResponse)
|
||||
|
||||
survey_text = """
|
||||
Survey ID: RESP-12345
|
||||
Completed: 2024-01-20
|
||||
|
||||
How satisfied are you with our product? 4 out of 5
|
||||
|
||||
Would you recommend to a friend? Yes
|
||||
|
||||
What features do you like most?
|
||||
- Fast performance
|
||||
- Easy to use
|
||||
- Great customer support
|
||||
|
||||
What could we improve?
|
||||
- Better documentation
|
||||
- More integrations
|
||||
|
||||
Additional feedback: Overall great product, keep up the good work!
|
||||
"""
|
||||
|
||||
response = generator(f"Extract survey response:\n{survey_text}\n\nResponse:")
|
||||
|
||||
print(f"Respondent: {response.respondent_id}")
|
||||
print(f"Satisfaction: {response.satisfaction}/5")
|
||||
print(f"Would recommend: {response.would_recommend}")
|
||||
print(f"Favorite features: {response.favorite_features}")
|
||||
print(f"Improvement areas: {response.improvement_areas}")
|
||||
```
|
||||
|
||||
## Multi-Entity Extraction
|
||||
|
||||
### News Article Entities
|
||||
|
||||
```python
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
role: Optional[str] = None
|
||||
affiliation: Optional[str] = None
|
||||
|
||||
class Organization(BaseModel):
|
||||
name: str
|
||||
type: Optional[str] = None
|
||||
|
||||
class Location(BaseModel):
|
||||
name: str
|
||||
type: Literal["city", "state", "country", "region"]
|
||||
|
||||
class Event(BaseModel):
|
||||
name: str
|
||||
date: Optional[str] = None
|
||||
location: Optional[str] = None
|
||||
|
||||
class ArticleEntities(BaseModel):
|
||||
people: list[Person]
|
||||
organizations: list[Organization]
|
||||
locations: list[Location]
|
||||
events: list[Event]
|
||||
dates: list[str]
|
||||
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
generator = outlines.generate.json(model, ArticleEntities)
|
||||
|
||||
article = """
|
||||
Apple CEO Tim Cook met with Microsoft CEO Satya Nadella at Microsoft
|
||||
headquarters in Redmond, Washington on September 15, 2024, to discuss
|
||||
potential collaboration opportunities. The meeting was attended by executives
|
||||
from both companies and focused on AI integration strategies. Apple's
|
||||
Cupertino offices will host a follow-up meeting on October 20, 2024.
|
||||
"""
|
||||
|
||||
entities = generator(f"Extract all entities:\n{article}\n\nEntities:")
|
||||
|
||||
print("People:")
|
||||
for person in entities.people:
|
||||
print(f" - {person.name} ({person.role}) @ {person.affiliation}")
|
||||
|
||||
print("\nOrganizations:")
|
||||
for org in entities.organizations:
|
||||
print(f" - {org.name} ({org.type})")
|
||||
|
||||
print("\nLocations:")
|
||||
for loc in entities.locations:
|
||||
print(f" - {loc.name} ({loc.type})")
|
||||
|
||||
print("\nEvents:")
|
||||
for event in entities.events:
|
||||
print(f" - {event.name} on {event.date}")
|
||||
```
|
||||
|
||||
### Document Metadata
|
||||
|
||||
```python
|
||||
class Author(BaseModel):
|
||||
name: str
|
||||
email: Optional[str] = None
|
||||
affiliation: Optional[str] = None
|
||||
|
||||
class Reference(BaseModel):
|
||||
title: str
|
||||
authors: list[str]
|
||||
year: int
|
||||
source: str
|
||||
|
||||
class DocumentMetadata(BaseModel):
|
||||
title: str
|
||||
authors: list[Author]
|
||||
abstract: str
|
||||
keywords: list[str]
|
||||
publication_date: str
|
||||
journal: str
|
||||
doi: Optional[str] = None
|
||||
references: list[Reference]
|
||||
|
||||
generator = outlines.generate.json(model, DocumentMetadata)
|
||||
|
||||
paper = """
|
||||
Title: Advances in Neural Machine Translation
|
||||
|
||||
Authors:
|
||||
- Dr. Jane Smith (jane@university.edu), MIT
|
||||
- Prof. John Doe (jdoe@stanford.edu), Stanford University
|
||||
|
||||
Abstract: This paper presents novel approaches to neural machine translation
|
||||
using transformer architectures. We demonstrate significant improvements in
|
||||
translation quality across multiple language pairs.
|
||||
|
||||
Keywords: Neural Networks, Machine Translation, Transformers, NLP
|
||||
|
||||
Published: Journal of AI Research, 2024-03-15
|
||||
DOI: 10.1234/jair.2024.001
|
||||
|
||||
References:
|
||||
1. "Attention Is All You Need" by Vaswani et al., 2017, NeurIPS
|
||||
2. "BERT: Pre-training of Deep Bidirectional Transformers" by Devlin et al., 2019, NAACL
|
||||
"""
|
||||
|
||||
metadata = generator(f"Extract document metadata:\n{paper}\n\nMetadata:")
|
||||
|
||||
print(f"Title: {metadata.title}")
|
||||
print(f"Authors: {', '.join(a.name for a in metadata.authors)}")
|
||||
print(f"Keywords: {', '.join(metadata.keywords)}")
|
||||
print(f"References: {len(metadata.references)}")
|
||||
```
|
||||
|
||||
## Code Generation
|
||||
|
||||
### Python Function Generation
|
||||
|
||||
```python
|
||||
class Parameter(BaseModel):
|
||||
name: str = Field(pattern=r"^[a-z_][a-z0-9_]*$")
|
||||
type_hint: str
|
||||
default: Optional[str] = None
|
||||
|
||||
class PythonFunction(BaseModel):
|
||||
function_name: str = Field(pattern=r"^[a-z_][a-z0-9_]*$")
|
||||
parameters: list[Parameter]
|
||||
return_type: str
|
||||
docstring: str
|
||||
body: list[str] # Lines of code
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, PythonFunction)
|
||||
|
||||
spec = "Create a function to calculate the factorial of a number"
|
||||
|
||||
func = generator(f"Generate Python function:\n{spec}\n\nFunction:")
|
||||
|
||||
print(f"def {func.function_name}(", end="")
|
||||
print(", ".join(f"{p.name}: {p.type_hint}" for p in func.parameters), end="")
|
||||
print(f") -> {func.return_type}:")
|
||||
print(f' """{func.docstring}"""')
|
||||
for line in func.body:
|
||||
print(f" {line}")
|
||||
```
|
||||
|
||||
### SQL Query Generation
|
||||
|
||||
```python
|
||||
class SQLQuery(BaseModel):
|
||||
query_type: Literal["SELECT", "INSERT", "UPDATE", "DELETE"]
|
||||
select_columns: Optional[list[str]] = None
|
||||
from_tables: list[str]
|
||||
joins: Optional[list[str]] = None
|
||||
where_conditions: Optional[list[str]] = None
|
||||
group_by: Optional[list[str]] = None
|
||||
order_by: Optional[list[str]] = None
|
||||
limit: Optional[int] = None
|
||||
|
||||
generator = outlines.generate.json(model, SQLQuery)
|
||||
|
||||
request = "Get top 10 users who made purchases in the last 30 days, ordered by total spent"
|
||||
|
||||
sql = generator(f"Generate SQL query:\n{request}\n\nQuery:")
|
||||
|
||||
print(f"Query type: {sql.query_type}")
|
||||
print(f"SELECT {', '.join(sql.select_columns)}")
|
||||
print(f"FROM {', '.join(sql.from_tables)}")
|
||||
if sql.joins:
|
||||
for join in sql.joins:
|
||||
print(f" {join}")
|
||||
if sql.where_conditions:
|
||||
print(f"WHERE {' AND '.join(sql.where_conditions)}")
|
||||
if sql.order_by:
|
||||
print(f"ORDER BY {', '.join(sql.order_by)}")
|
||||
if sql.limit:
|
||||
print(f"LIMIT {sql.limit}")
|
||||
```
|
||||
|
||||
### API Endpoint Spec
|
||||
|
||||
```python
|
||||
class Parameter(BaseModel):
|
||||
name: str
|
||||
type: str
|
||||
required: bool
|
||||
description: str
|
||||
|
||||
class APIEndpoint(BaseModel):
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
|
||||
path: str
|
||||
description: str
|
||||
parameters: list[Parameter]
|
||||
request_body: Optional[dict] = None
|
||||
response_schema: dict
|
||||
status_codes: dict[int, str]
|
||||
|
||||
generator = outlines.generate.json(model, APIEndpoint)
|
||||
|
||||
spec = "Create user endpoint"
|
||||
|
||||
endpoint = generator(f"Generate API endpoint:\n{spec}\n\nEndpoint:")
|
||||
|
||||
print(f"{endpoint.method} {endpoint.path}")
|
||||
print(f"Description: {endpoint.description}")
|
||||
print("\nParameters:")
|
||||
for param in endpoint.parameters:
|
||||
req = "required" if param.required else "optional"
|
||||
print(f" - {param.name} ({param.type}, {req}): {param.description}")
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
### Parallel Extraction
|
||||
|
||||
```python
|
||||
def batch_extract(texts: list[str], schema: type[BaseModel], model_name: str):
|
||||
"""Extract structured data from multiple texts."""
|
||||
model = outlines.models.transformers(model_name)
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
for i, text in enumerate(texts):
|
||||
print(f"Processing {i+1}/{len(texts)}...", end="\r")
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
category: str
|
||||
|
||||
texts = [
|
||||
"iPhone 15 Pro costs $999 in Electronics",
|
||||
"Running Shoes are $89.99 in Sports",
|
||||
"Coffee Maker priced at $49.99 in Home & Kitchen"
|
||||
]
|
||||
|
||||
products = batch_extract(texts, Product, "microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
for product in products:
|
||||
print(f"{product.name}: ${product.price} ({product.category})")
|
||||
```
|
||||
|
||||
### CSV Processing
|
||||
|
||||
```python
|
||||
import csv
|
||||
|
||||
def process_csv(csv_file: str, schema: type[BaseModel]):
|
||||
"""Process CSV file and extract structured data."""
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
with open(csv_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
text = " | ".join(f"{k}: {v}" for k, v in row.items())
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
class Customer(BaseModel):
|
||||
name: str
|
||||
email: str
|
||||
tier: Literal["basic", "premium", "enterprise"]
|
||||
mrr: float
|
||||
|
||||
# customers = process_csv("customers.csv", Customer)
|
||||
```
|
||||
|
||||
## Production Patterns
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
def safe_extract(text: str, schema: type[BaseModel], retries: int = 3):
|
||||
"""Extract with error handling and retries."""
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
return result
|
||||
except ValidationError as e:
|
||||
print(f"Attempt {attempt + 1} failed: {e}")
|
||||
if attempt == retries - 1:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
if attempt == retries - 1:
|
||||
raise
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
import hashlib
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def cached_extract(text_hash: str, schema_name: str):
|
||||
"""Cache extraction results."""
|
||||
# This would be called with actual extraction logic
|
||||
pass
|
||||
|
||||
def extract_with_cache(text: str, schema: type[BaseModel]):
|
||||
"""Extract with caching."""
|
||||
text_hash = hashlib.md5(text.encode()).hexdigest()
|
||||
schema_name = schema.__name__
|
||||
|
||||
cached_result = cached_extract(text_hash, schema_name)
|
||||
if cached_result:
|
||||
return cached_result
|
||||
|
||||
# Perform actual extraction
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
### Monitoring
|
||||
|
||||
```python
|
||||
import time
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def monitored_extract(text: str, schema: type[BaseModel]):
|
||||
"""Extract with monitoring and logging."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Extraction succeeded in {elapsed:.2f}s")
|
||||
logger.info(f"Input length: {len(text)} chars")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
logger.error(f"Extraction failed after {elapsed:.2f}s: {e}")
|
||||
raise
|
||||
```
|
||||
|
||||
### Rate Limiting
|
||||
|
||||
```python
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, max_requests: int, time_window: int):
|
||||
self.max_requests = max_requests
|
||||
self.time_window = time_window
|
||||
self.requests = []
|
||||
self.lock = Lock()
|
||||
|
||||
def wait_if_needed(self):
|
||||
with self.lock:
|
||||
now = time.time()
|
||||
# Remove old requests
|
||||
self.requests = [r for r in self.requests if now - r < self.time_window]
|
||||
|
||||
if len(self.requests) >= self.max_requests:
|
||||
sleep_time = self.time_window - (now - self.requests[0])
|
||||
time.sleep(sleep_time)
|
||||
self.requests = []
|
||||
|
||||
self.requests.append(now)
|
||||
|
||||
def rate_limited_extract(texts: list[str], schema: type[BaseModel]):
|
||||
"""Extract with rate limiting."""
|
||||
limiter = RateLimiter(max_requests=10, time_window=60) # 10 req/min
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
for text in texts:
|
||||
limiter.wait_if_needed()
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Outlines Documentation**: https://outlines-dev.github.io/outlines
|
||||
- **Pydantic Documentation**: https://docs.pydantic.dev
|
||||
- **GitHub Examples**: https://github.com/outlines-dev/outlines/tree/main/examples
|
||||
|
|
@ -0,0 +1,652 @@
|
|||
# Comprehensive JSON Generation Guide
|
||||
|
||||
Complete guide to JSON generation with Outlines using Pydantic models and JSON schemas.
|
||||
|
||||
## Table of Contents
|
||||
- Pydantic Models
|
||||
- JSON Schema Support
|
||||
- Advanced Patterns
|
||||
- Nested Structures
|
||||
- Complex Types
|
||||
- Validation
|
||||
- Performance Optimization
|
||||
|
||||
## Pydantic Models
|
||||
|
||||
### Basic Models
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
import outlines
|
||||
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, User)
|
||||
|
||||
user = generator("Generate user: Alice, 25, alice@example.com")
|
||||
print(user.name) # "Alice"
|
||||
print(user.age) # 25
|
||||
print(user.email) # "alice@example.com"
|
||||
```
|
||||
|
||||
###
|
||||
|
||||
Field Constraints
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=100)
|
||||
price: float = Field(gt=0, description="Price in USD")
|
||||
discount: float = Field(ge=0, le=100, description="Discount percentage")
|
||||
quantity: int = Field(ge=0, description="Available quantity")
|
||||
sku: str = Field(pattern=r"^[A-Z]{3}-\d{6}$")
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, Product)
|
||||
|
||||
product = generator("Generate product: iPhone 15, $999")
|
||||
# All fields guaranteed to meet constraints
|
||||
```
|
||||
|
||||
**Available Constraints:**
|
||||
- `min_length`, `max_length`: String length
|
||||
- `gt`, `ge`, `lt`, `le`: Numeric comparisons
|
||||
- `multiple_of`: Number must be multiple of value
|
||||
- `pattern`: Regex pattern for strings
|
||||
- `min_items`, `max_items`: List length
|
||||
|
||||
### Optional Fields
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
class Article(BaseModel):
|
||||
title: str # Required
|
||||
author: Optional[str] = None # Optional
|
||||
published_date: Optional[str] = None # Optional
|
||||
tags: list[str] = [] # Default empty list
|
||||
view_count: int = 0 # Default value
|
||||
|
||||
generator = outlines.generate.json(model, Article)
|
||||
|
||||
# Can generate even if optional fields missing
|
||||
article = generator("Title: Introduction to AI")
|
||||
print(article.author) # None (not provided)
|
||||
print(article.tags) # [] (default)
|
||||
```
|
||||
|
||||
### Default Values
|
||||
|
||||
```python
|
||||
class Config(BaseModel):
|
||||
debug: bool = False
|
||||
max_retries: int = 3
|
||||
timeout: float = 30.0
|
||||
log_level: str = "INFO"
|
||||
|
||||
# Generator uses defaults when not specified
|
||||
generator = outlines.generate.json(model, Config)
|
||||
config = generator("Generate config with debug enabled")
|
||||
print(config.debug) # True (from prompt)
|
||||
print(config.timeout) # 30.0 (default)
|
||||
```
|
||||
|
||||
## Enums and Literals
|
||||
|
||||
### Enum Fields
|
||||
|
||||
```python
|
||||
from enum import Enum
|
||||
|
||||
class Status(str, Enum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
class Application(BaseModel):
|
||||
applicant_name: str
|
||||
status: Status # Must be one of enum values
|
||||
submitted_date: str
|
||||
|
||||
generator = outlines.generate.json(model, Application)
|
||||
app = generator("Generate application for John Doe")
|
||||
|
||||
print(app.status) # Status.PENDING (or one of the enum values)
|
||||
print(type(app.status)) # <enum 'Status'>
|
||||
```
|
||||
|
||||
### Literal Types
|
||||
|
||||
```python
|
||||
from typing import Literal
|
||||
|
||||
class Task(BaseModel):
|
||||
title: str
|
||||
priority: Literal["low", "medium", "high", "critical"]
|
||||
status: Literal["todo", "in_progress", "done"]
|
||||
assigned_to: str
|
||||
|
||||
generator = outlines.generate.json(model, Task)
|
||||
task = generator("Create high priority task: Fix bug")
|
||||
|
||||
print(task.priority) # One of: "low", "medium", "high", "critical"
|
||||
```
|
||||
|
||||
### Multiple Choice Fields
|
||||
|
||||
```python
|
||||
class Survey(BaseModel):
|
||||
question: str
|
||||
answer: Literal["strongly_disagree", "disagree", "neutral", "agree", "strongly_agree"]
|
||||
confidence: Literal["low", "medium", "high"]
|
||||
|
||||
generator = outlines.generate.json(model, Survey)
|
||||
survey = generator("Rate: 'I enjoy using this product'")
|
||||
```
|
||||
|
||||
## Nested Structures
|
||||
|
||||
### Nested Models
|
||||
|
||||
```python
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
state: str
|
||||
zip_code: str
|
||||
country: str = "USA"
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
address: Address # Nested model
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, Person)
|
||||
|
||||
prompt = """
|
||||
Extract person:
|
||||
Name: Alice Johnson
|
||||
Age: 28
|
||||
Email: alice@example.com
|
||||
Address: 123 Main St, Boston, MA, 02101
|
||||
"""
|
||||
|
||||
person = generator(prompt)
|
||||
print(person.name) # "Alice Johnson"
|
||||
print(person.address.city) # "Boston"
|
||||
print(person.address.state) # "MA"
|
||||
```
|
||||
|
||||
### Deep Nesting
|
||||
|
||||
```python
|
||||
class Coordinates(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
|
||||
class Location(BaseModel):
|
||||
name: str
|
||||
coordinates: Coordinates
|
||||
|
||||
class Event(BaseModel):
|
||||
title: str
|
||||
date: str
|
||||
location: Location
|
||||
|
||||
generator = outlines.generate.json(model, Event)
|
||||
event = generator("Generate event: Tech Conference in San Francisco")
|
||||
|
||||
print(event.title) # "Tech Conference"
|
||||
print(event.location.name) # "San Francisco"
|
||||
print(event.location.coordinates.latitude) # 37.7749
|
||||
```
|
||||
|
||||
### Lists of Nested Models
|
||||
|
||||
```python
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
quantity: int
|
||||
price: float
|
||||
|
||||
class Order(BaseModel):
|
||||
order_id: str
|
||||
customer: str
|
||||
items: list[Item] # List of nested models
|
||||
total: float
|
||||
|
||||
generator = outlines.generate.json(model, Order)
|
||||
|
||||
prompt = """
|
||||
Generate order for John:
|
||||
- 2x Widget ($10 each)
|
||||
- 3x Gadget ($15 each)
|
||||
Order ID: ORD-001
|
||||
"""
|
||||
|
||||
order = generator(prompt)
|
||||
print(f"Order ID: {order.order_id}")
|
||||
for item in order.items:
|
||||
print(f"- {item.quantity}x {item.name} @ ${item.price}")
|
||||
print(f"Total: ${order.total}")
|
||||
```
|
||||
|
||||
## Complex Types
|
||||
|
||||
### Union Types
|
||||
|
||||
```python
|
||||
from typing import Union
|
||||
|
||||
class TextContent(BaseModel):
|
||||
type: Literal["text"]
|
||||
content: str
|
||||
|
||||
class ImageContent(BaseModel):
|
||||
type: Literal["image"]
|
||||
url: str
|
||||
caption: str
|
||||
|
||||
class Post(BaseModel):
|
||||
title: str
|
||||
content: Union[TextContent, ImageContent] # Either type
|
||||
|
||||
generator = outlines.generate.json(model, Post)
|
||||
|
||||
# Can generate either text or image content
|
||||
post = generator("Generate blog post with image")
|
||||
if post.content.type == "text":
|
||||
print(post.content.content)
|
||||
elif post.content.type == "image":
|
||||
print(post.content.url)
|
||||
```
|
||||
|
||||
### Lists and Arrays
|
||||
|
||||
```python
|
||||
class Article(BaseModel):
|
||||
title: str
|
||||
authors: list[str] # List of strings
|
||||
tags: list[str]
|
||||
sections: list[dict[str, str]] # List of dicts
|
||||
related_ids: list[int]
|
||||
|
||||
generator = outlines.generate.json(model, Article)
|
||||
article = generator("Generate article about AI")
|
||||
|
||||
print(article.authors) # ["Alice", "Bob"]
|
||||
print(article.tags) # ["AI", "Machine Learning", "Technology"]
|
||||
```
|
||||
|
||||
### Dictionaries
|
||||
|
||||
```python
|
||||
class Metadata(BaseModel):
|
||||
title: str
|
||||
properties: dict[str, str] # String keys and values
|
||||
counts: dict[str, int] # String keys, int values
|
||||
settings: dict[str, Union[str, int, bool]] # Mixed value types
|
||||
|
||||
generator = outlines.generate.json(model, Metadata)
|
||||
meta = generator("Generate metadata")
|
||||
|
||||
print(meta.properties) # {"author": "Alice", "version": "1.0"}
|
||||
print(meta.counts) # {"views": 1000, "likes": 50}
|
||||
```
|
||||
|
||||
### Any Type (Use Sparingly)
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
|
||||
class FlexibleData(BaseModel):
|
||||
name: str
|
||||
structured_field: str
|
||||
flexible_field: Any # Can be anything
|
||||
|
||||
# Note: Any reduces type safety, use only when necessary
|
||||
generator = outlines.generate.json(model, FlexibleData)
|
||||
```
|
||||
|
||||
## JSON Schema Support
|
||||
|
||||
### Direct Schema Usage
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Define JSON schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer", "minimum": 0, "maximum": 120},
|
||||
"email": {"type": "string", "format": "email"}
|
||||
},
|
||||
"required": ["name", "age", "email"]
|
||||
}
|
||||
|
||||
# Generate from schema
|
||||
generator = outlines.generate.json(model, schema)
|
||||
result = generator("Generate person: Alice, 25, alice@example.com")
|
||||
|
||||
print(result) # Valid JSON matching schema
|
||||
```
|
||||
|
||||
### Schema from Pydantic
|
||||
|
||||
```python
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
# Get JSON schema from Pydantic model
|
||||
schema = User.model_json_schema()
|
||||
print(schema)
|
||||
# {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "name": {"type": "string"},
|
||||
# "age": {"type": "integer"},
|
||||
# "email": {"type": "string"}
|
||||
# },
|
||||
# "required": ["name", "age", "email"]
|
||||
# }
|
||||
|
||||
# Both approaches equivalent:
|
||||
generator1 = outlines.generate.json(model, User)
|
||||
generator2 = outlines.generate.json(model, schema)
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Conditional Fields
|
||||
|
||||
```python
|
||||
class Order(BaseModel):
|
||||
order_type: Literal["standard", "express"]
|
||||
delivery_date: str
|
||||
express_fee: Optional[float] = None # Only for express orders
|
||||
|
||||
generator = outlines.generate.json(model, Order)
|
||||
|
||||
# Express order
|
||||
order1 = generator("Create express order for tomorrow")
|
||||
print(order1.express_fee) # 25.0
|
||||
|
||||
# Standard order
|
||||
order2 = generator("Create standard order")
|
||||
print(order2.express_fee) # None
|
||||
```
|
||||
|
||||
### Recursive Models
|
||||
|
||||
```python
|
||||
from typing import Optional, List
|
||||
|
||||
class TreeNode(BaseModel):
|
||||
value: str
|
||||
children: Optional[List['TreeNode']] = None
|
||||
|
||||
# Enable forward references
|
||||
TreeNode.model_rebuild()
|
||||
|
||||
generator = outlines.generate.json(model, TreeNode)
|
||||
tree = generator("Generate file tree with subdirectories")
|
||||
|
||||
print(tree.value) # "root"
|
||||
print(tree.children[0].value) # "subdir1"
|
||||
```
|
||||
|
||||
### Model with Validation
|
||||
|
||||
```python
|
||||
from pydantic import field_validator
|
||||
|
||||
class DateRange(BaseModel):
|
||||
start_date: str
|
||||
end_date: str
|
||||
|
||||
@field_validator('end_date')
|
||||
def end_after_start(cls, v, info):
|
||||
"""Ensure end_date is after start_date."""
|
||||
if 'start_date' in info.data:
|
||||
from datetime import datetime
|
||||
start = datetime.strptime(info.data['start_date'], '%Y-%m-%d')
|
||||
end = datetime.strptime(v, '%Y-%m-%d')
|
||||
if end < start:
|
||||
raise ValueError('end_date must be after start_date')
|
||||
return v
|
||||
|
||||
generator = outlines.generate.json(model, DateRange)
|
||||
# Validation happens after generation
|
||||
```
|
||||
|
||||
## Multiple Objects
|
||||
|
||||
### Generate List of Objects
|
||||
|
||||
```python
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
class Team(BaseModel):
|
||||
team_name: str
|
||||
members: list[Person]
|
||||
|
||||
generator = outlines.generate.json(model, Team)
|
||||
|
||||
team = generator("Generate engineering team with 5 members")
|
||||
print(f"Team: {team.team_name}")
|
||||
for member in team.members:
|
||||
print(f"- {member.name}, {member.age}")
|
||||
```
|
||||
|
||||
### Batch Generation
|
||||
|
||||
```python
|
||||
def generate_batch(prompts: list[str], schema: type[BaseModel]):
|
||||
"""Generate structured outputs for multiple prompts."""
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
for prompt in prompts:
|
||||
result = generator(prompt)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
|
||||
prompts = [
|
||||
"Product: iPhone 15, $999",
|
||||
"Product: MacBook Pro, $2499",
|
||||
"Product: AirPods, $179"
|
||||
]
|
||||
|
||||
products = generate_batch(prompts, Product)
|
||||
for product in products:
|
||||
print(f"{product.name}: ${product.price}")
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Caching Generators
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def get_generator(model_name: str, schema_hash: int):
|
||||
"""Cache generators for reuse."""
|
||||
model = outlines.models.transformers(model_name)
|
||||
return outlines.generate.json(model, schema)
|
||||
|
||||
# First call: creates generator
|
||||
gen1 = get_generator("microsoft/Phi-3-mini-4k-instruct", hash(User))
|
||||
|
||||
# Second call: returns cached generator (fast!)
|
||||
gen2 = get_generator("microsoft/Phi-3-mini-4k-instruct", hash(User))
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
# Process multiple items efficiently
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, User)
|
||||
|
||||
texts = ["User: Alice, 25", "User: Bob, 30", "User: Carol, 35"]
|
||||
|
||||
# Reuse generator (model stays loaded)
|
||||
users = [generator(text) for text in texts]
|
||||
```
|
||||
|
||||
### Minimize Schema Complexity
|
||||
|
||||
```python
|
||||
# ✅ Good: Simple, flat structure (faster)
|
||||
class SimplePerson(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
city: str
|
||||
|
||||
# ⚠️ Slower: Deep nesting
|
||||
class ComplexPerson(BaseModel):
|
||||
personal_info: PersonalInfo
|
||||
address: Address
|
||||
employment: Employment
|
||||
# ... many nested levels
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Handle Missing Fields
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
try:
|
||||
user = generator("Generate user") # May not include all fields
|
||||
except ValidationError as e:
|
||||
print(f"Validation error: {e}")
|
||||
# Handle gracefully
|
||||
```
|
||||
|
||||
### Fallback with Optional Fields
|
||||
|
||||
```python
|
||||
class RobustUser(BaseModel):
|
||||
name: str # Required
|
||||
age: Optional[int] = None # Optional
|
||||
email: Optional[str] = None # Optional
|
||||
|
||||
# More likely to succeed even with incomplete data
|
||||
user = generator("Generate user: Alice")
|
||||
print(user.name) # "Alice"
|
||||
print(user.age) # None (not provided)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Specific Types
|
||||
|
||||
```python
|
||||
# ✅ Good: Specific types
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float # Not Any or str
|
||||
quantity: int # Not str
|
||||
in_stock: bool # Not int
|
||||
|
||||
# ❌ Bad: Generic types
|
||||
class Product(BaseModel):
|
||||
name: Any
|
||||
price: str # Should be float
|
||||
quantity: str # Should be int
|
||||
```
|
||||
|
||||
### 2. Add Descriptions
|
||||
|
||||
```python
|
||||
# ✅ Good: Clear descriptions
|
||||
class Article(BaseModel):
|
||||
title: str = Field(description="Article title, 10-100 characters")
|
||||
content: str = Field(description="Main article content in paragraphs")
|
||||
tags: list[str] = Field(description="List of relevant topic tags")
|
||||
|
||||
# Descriptions help the model understand expected output
|
||||
```
|
||||
|
||||
### 3. Use Constraints
|
||||
|
||||
```python
|
||||
# ✅ Good: With constraints
|
||||
class Age(BaseModel):
|
||||
value: int = Field(ge=0, le=120, description="Age in years")
|
||||
|
||||
# ❌ Bad: No constraints
|
||||
class Age(BaseModel):
|
||||
value: int # Could be negative or > 120
|
||||
```
|
||||
|
||||
### 4. Prefer Enums Over Strings
|
||||
|
||||
```python
|
||||
# ✅ Good: Enum for fixed set
|
||||
class Priority(str, Enum):
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
|
||||
class Task(BaseModel):
|
||||
priority: Priority # Guaranteed valid
|
||||
|
||||
# ❌ Bad: Free-form string
|
||||
class Task(BaseModel):
|
||||
priority: str # Could be "urgent", "ASAP", "!!", etc.
|
||||
```
|
||||
|
||||
### 5. Test Your Models
|
||||
|
||||
```python
|
||||
# Test models work as expected
|
||||
def test_product_model():
|
||||
product = Product(
|
||||
name="Test Product",
|
||||
price=19.99,
|
||||
quantity=10,
|
||||
in_stock=True
|
||||
)
|
||||
assert product.price == 19.99
|
||||
assert isinstance(product, Product)
|
||||
|
||||
# Run tests before using in production
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Pydantic Docs**: https://docs.pydantic.dev
|
||||
- **JSON Schema**: https://json-schema.org
|
||||
- **Outlines GitHub**: https://github.com/outlines-dev/outlines
|
||||
166
optional-skills/mlops/training/axolotl/SKILL.md
Normal file
166
optional-skills/mlops/training/axolotl/SKILL.md
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
---
|
||||
name: axolotl
|
||||
description: "Axolotl: YAML LLM fine-tuning (LoRA, DPO, GRPO)."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [axolotl, torch, transformers, datasets, peft, accelerate, deepspeed]
|
||||
platforms: [linux, macos]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Fine-Tuning, Axolotl, LLM, LoRA, QLoRA, DPO, KTO, ORPO, GRPO, YAML, HuggingFace, DeepSpeed, Multimodal]
|
||||
|
||||
---
|
||||
|
||||
# Axolotl Skill
|
||||
|
||||
## What's inside
|
||||
|
||||
Expert guidance for fine-tuning LLMs with Axolotl — YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support.
|
||||
|
||||
Comprehensive assistance with axolotl development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with axolotl
|
||||
- Asking about axolotl features or APIs
|
||||
- Implementing axolotl solutions
|
||||
- Debugging axolotl code
|
||||
- Learning axolotl best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
**Pattern 1:** To validate that acceptable data transfer speeds exist for your training job, running NCCL Tests can help pinpoint bottlenecks, for example:
|
||||
|
||||
```
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
**Pattern 2:** Configure your model to use FSDP in the Axolotl yaml. For example:
|
||||
|
||||
```
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
```
|
||||
|
||||
**Pattern 3:** The context_parallel_size should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
```
|
||||
context_parallel_size
|
||||
```
|
||||
|
||||
**Pattern 4:** For example: - With 8 GPUs and no sequence parallelism: 8 different batches processed per step - With 8 GPUs and context_parallel_size=4: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
```
|
||||
context_parallel_size=4
|
||||
```
|
||||
|
||||
**Pattern 5:** Setting save_compressed: true in your configuration enables saving models in a compressed format, which: - Reduces disk space usage by approximately 40% - Maintains compatibility with vLLM for accelerated inference - Maintains compatibility with llmcompressor for further optimization (example: quantization)
|
||||
|
||||
```
|
||||
save_compressed: true
|
||||
```
|
||||
|
||||
**Pattern 6:** Note It is not necessary to place your integration in the integrations folder. It can be in any location, so long as it’s installed in a package in your python env. See this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer
|
||||
|
||||
```
|
||||
integrations
|
||||
```
|
||||
|
||||
**Pattern 7:** Handle both single-example and batched data. - single example: sample[‘input_ids’] is a list[int] - batched data: sample[‘input_ids’] is a list[list[int]]
|
||||
|
||||
```
|
||||
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
|
||||
```
|
||||
|
||||
### Example Code Patterns
|
||||
|
||||
**Example 1** (python):
|
||||
```python
|
||||
cli.cloud.modal_.ModalCloud(config, app=None)
|
||||
```
|
||||
|
||||
**Example 2** (python):
|
||||
```python
|
||||
cli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)
|
||||
```
|
||||
|
||||
**Example 3** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer(
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
**Example 4** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)
|
||||
```
|
||||
|
||||
**Example 5** (python):
|
||||
```python
|
||||
prompt_strategies.input_output.RawInputOutputPrompter()
|
||||
```
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **api.md** - Api documentation
|
||||
- **dataset-formats.md** - Dataset-Formats documentation
|
||||
- **other.md** - Other documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
|
||||
5548
optional-skills/mlops/training/axolotl/references/api.md
Normal file
5548
optional-skills/mlops/training/axolotl/references/api.md
Normal file
File diff suppressed because it is too large
Load diff
1029
optional-skills/mlops/training/axolotl/references/dataset-formats.md
Normal file
1029
optional-skills/mlops/training/axolotl/references/dataset-formats.md
Normal file
File diff suppressed because it is too large
Load diff
15
optional-skills/mlops/training/axolotl/references/index.md
Normal file
15
optional-skills/mlops/training/axolotl/references/index.md
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# Axolotl Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Api
|
||||
**File:** `api.md`
|
||||
**Pages:** 150
|
||||
|
||||
### Dataset-Formats
|
||||
**File:** `dataset-formats.md`
|
||||
**Pages:** 9
|
||||
|
||||
### Other
|
||||
**File:** `other.md`
|
||||
**Pages:** 26
|
||||
3563
optional-skills/mlops/training/axolotl/references/other.md
Normal file
3563
optional-skills/mlops/training/axolotl/references/other.md
Normal file
File diff suppressed because it is too large
Load diff
463
optional-skills/mlops/training/trl-fine-tuning/SKILL.md
Normal file
463
optional-skills/mlops/training/trl-fine-tuning/SKILL.md
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
---
|
||||
name: fine-tuning-with-trl
|
||||
description: "TRL: SFT, DPO, PPO, GRPO, reward modeling for LLM RLHF."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [trl, transformers, datasets, peft, accelerate, torch]
|
||||
platforms: [linux, macos, windows]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, TRL, Reinforcement Learning, Fine-Tuning, SFT, DPO, PPO, GRPO, RLHF, Preference Alignment, HuggingFace]
|
||||
|
||||
---
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
## Quick start
|
||||
|
||||
TRL provides post-training methods for aligning language models with human preferences.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install trl transformers datasets peft accelerate
|
||||
```
|
||||
|
||||
**Supervised Fine-Tuning** (instruction tuning):
|
||||
```python
|
||||
from trl import SFTTrainer
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset, # Prompt-completion pairs
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
**DPO** (align with preferences):
|
||||
```python
|
||||
from trl import DPOTrainer, DPOConfig
|
||||
|
||||
config = DPOConfig(output_dir="model-dpo", beta=0.1)
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=preference_dataset, # chosen/rejected pairs
|
||||
processing_class=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Full RLHF pipeline (SFT → Reward Model → PPO)
|
||||
|
||||
Complete pipeline from base model to human-aligned model.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
RLHF Training:
|
||||
- [ ] Step 1: Supervised fine-tuning (SFT)
|
||||
- [ ] Step 2: Train reward model
|
||||
- [ ] Step 3: PPO reinforcement learning
|
||||
- [ ] Step 4: Evaluate aligned model
|
||||
```
|
||||
|
||||
**Step 1: Supervised fine-tuning**
|
||||
|
||||
Train base model on instruction-following data:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
|
||||
# Load instruction dataset
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
# Configure training
|
||||
training_args = SFTConfig(
|
||||
output_dir="Qwen2.5-0.5B-SFT",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=2e-5,
|
||||
logging_steps=10,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**Step 2: Train reward model**
|
||||
|
||||
Train model to predict human preferences:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
|
||||
# Load SFT model as base
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"Qwen2.5-0.5B-SFT",
|
||||
num_labels=1 # Single reward score
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-SFT")
|
||||
|
||||
# Load preference data (chosen/rejected pairs)
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
# Configure training
|
||||
training_args = RewardConfig(
|
||||
output_dir="Qwen2.5-0.5B-Reward",
|
||||
per_device_train_batch_size=2,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5
|
||||
)
|
||||
|
||||
# Train reward model
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**Step 3: PPO reinforcement learning**
|
||||
|
||||
Optimize policy using reward model:
|
||||
|
||||
```bash
|
||||
python -m trl.scripts.ppo \
|
||||
--model_name_or_path Qwen2.5-0.5B-SFT \
|
||||
--reward_model_path Qwen2.5-0.5B-Reward \
|
||||
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
||||
--output_dir Qwen2.5-0.5B-PPO \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--total_episodes 10000
|
||||
```
|
||||
|
||||
**Step 4: Evaluate**
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Load aligned model
|
||||
generator = pipeline("text-generation", model="Qwen2.5-0.5B-PPO")
|
||||
|
||||
# Test
|
||||
prompt = "Explain quantum computing to a 10-year-old"
|
||||
output = generator(prompt, max_length=200)[0]["generated_text"]
|
||||
print(output)
|
||||
```
|
||||
|
||||
### Workflow 2: Simple preference alignment with DPO
|
||||
|
||||
Align model with preferences without reward model.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
DPO Training:
|
||||
- [ ] Step 1: Prepare preference dataset
|
||||
- [ ] Step 2: Configure DPO
|
||||
- [ ] Step 3: Train with DPOTrainer
|
||||
- [ ] Step 4: Evaluate alignment
|
||||
```
|
||||
|
||||
**Step 1: Prepare preference dataset**
|
||||
|
||||
Dataset format:
|
||||
```json
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"chosen": "The capital of France is Paris.",
|
||||
"rejected": "I don't know."
|
||||
}
|
||||
```
|
||||
|
||||
Load dataset:
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
# Or load your own
|
||||
# dataset = load_dataset("json", data_files="preferences.json")
|
||||
```
|
||||
|
||||
**Step 2: Configure DPO**
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
config = DPOConfig(
|
||||
output_dir="Qwen2.5-0.5B-DPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=5e-7,
|
||||
beta=0.1, # KL penalty strength
|
||||
max_prompt_length=512,
|
||||
max_length=1024,
|
||||
logging_steps=10
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: Train with DPOTrainer**
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**CLI alternative**:
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name argilla/Capybara-Preferences \
|
||||
--output_dir Qwen2.5-0.5B-DPO \
|
||||
--per_device_train_batch_size 4 \
|
||||
--learning_rate 5e-7 \
|
||||
--beta 0.1
|
||||
```
|
||||
|
||||
### Workflow 3: Memory-efficient online RL with GRPO
|
||||
|
||||
Train with reinforcement learning using minimal memory.
|
||||
|
||||
For in-depth GRPO guidance — reward function design, critical training insights (loss behavior, mode collapse, tuning), and advanced multi-stage patterns — see **[references/grpo-training.md](references/grpo-training.md)**. A production-ready training script is in **[templates/basic_grpo_training.py](templates/basic_grpo_training.py)**.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
GRPO Training:
|
||||
- [ ] Step 1: Define reward function
|
||||
- [ ] Step 2: Configure GRPO
|
||||
- [ ] Step 3: Train with GRPOTrainer
|
||||
```
|
||||
|
||||
**Step 1: Define reward function**
|
||||
|
||||
```python
|
||||
def reward_function(completions, **kwargs):
|
||||
"""
|
||||
Compute rewards for completions.
|
||||
|
||||
Args:
|
||||
completions: List of generated texts
|
||||
|
||||
Returns:
|
||||
List of reward scores (floats)
|
||||
"""
|
||||
rewards = []
|
||||
for completion in completions:
|
||||
# Example: reward based on length and unique words
|
||||
score = len(completion.split()) # Favor longer responses
|
||||
score += len(set(completion.lower().split())) # Reward unique words
|
||||
rewards.append(score)
|
||||
return rewards
|
||||
```
|
||||
|
||||
Or use a reward model:
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
reward_model = pipeline("text-classification", model="reward-model-path")
|
||||
|
||||
def reward_from_model(completions, prompts, **kwargs):
|
||||
# Combine prompt + completion
|
||||
full_texts = [p + c for p, c in zip(prompts, completions)]
|
||||
# Get reward scores
|
||||
results = reward_model(full_texts)
|
||||
return [r["score"] for r in results]
|
||||
```
|
||||
|
||||
**Step 2: Configure GRPO**
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
config = GRPOConfig(
|
||||
output_dir="Qwen2-GRPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5,
|
||||
num_generations=4, # Generate 4 completions per prompt
|
||||
max_new_tokens=128
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: Train with GRPOTrainer**
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load prompt-only dataset
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_function, # Your reward function
|
||||
args=config,
|
||||
train_dataset=dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
**CLI**:
|
||||
```bash
|
||||
trl grpo \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--output_dir Qwen2-GRPO \
|
||||
--num_generations 4
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use TRL when:**
|
||||
- Need to align model with human preferences
|
||||
- Have preference data (chosen/rejected pairs)
|
||||
- Want to use reinforcement learning (PPO, GRPO)
|
||||
- Need reward model training
|
||||
- Doing RLHF (full pipeline)
|
||||
|
||||
**Method selection**:
|
||||
- **SFT**: Have prompt-completion pairs, want basic instruction following
|
||||
- **DPO**: Have preferences, want simple alignment (no reward model needed)
|
||||
- **PPO**: Have reward model, need maximum control over RL
|
||||
- **GRPO**: Memory-constrained, want online RL
|
||||
- **Reward Model**: Building RLHF pipeline, need to score generations
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **HuggingFace Trainer**: Basic fine-tuning without RL
|
||||
- **Axolotl**: YAML-based training configuration
|
||||
- **LitGPT**: Educational, minimal fine-tuning
|
||||
- **Unsloth**: Fast LoRA training
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: OOM during DPO training**
|
||||
|
||||
Reduce batch size and sequence length:
|
||||
```python
|
||||
config = DPOConfig(
|
||||
per_device_train_batch_size=1, # Reduce from 4
|
||||
max_length=512, # Reduce from 1024
|
||||
gradient_accumulation_steps=8 # Maintain effective batch
|
||||
)
|
||||
```
|
||||
|
||||
Or use gradient checkpointing:
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
|
||||
**Issue: Poor alignment quality**
|
||||
|
||||
Tune beta parameter:
|
||||
```python
|
||||
# Higher beta = more conservative (stays closer to reference)
|
||||
config = DPOConfig(beta=0.5) # Default 0.1
|
||||
|
||||
# Lower beta = more aggressive alignment
|
||||
config = DPOConfig(beta=0.01)
|
||||
```
|
||||
|
||||
**Issue: Reward model not learning**
|
||||
|
||||
Check loss type and learning rate:
|
||||
```python
|
||||
config = RewardConfig(
|
||||
learning_rate=1e-5, # Try different LR
|
||||
num_train_epochs=3 # Train longer
|
||||
)
|
||||
```
|
||||
|
||||
Ensure preference dataset has clear winners:
|
||||
```python
|
||||
# Verify dataset
|
||||
print(dataset[0])
|
||||
# Should have clear chosen > rejected
|
||||
```
|
||||
|
||||
**Issue: PPO training unstable**
|
||||
|
||||
Adjust KL coefficient:
|
||||
```python
|
||||
config = PPOConfig(
|
||||
kl_coef=0.1, # Increase from 0.05
|
||||
cliprange=0.1 # Reduce from 0.2
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**SFT training guide**: See [references/sft-training.md](references/sft-training.md) for dataset formats, chat templates, packing strategies, and multi-GPU training.
|
||||
|
||||
**DPO variants**: See [references/dpo-variants.md](references/dpo-variants.md) for IPO, cDPO, RPO, and other DPO loss functions with recommended hyperparameters.
|
||||
|
||||
**Reward modeling**: See [references/reward-modeling.md](references/reward-modeling.md) for outcome vs process rewards, Bradley-Terry loss, and reward model evaluation.
|
||||
|
||||
**Online RL methods**: See [references/online-rl.md](references/online-rl.md) for PPO, GRPO, RLOO, and OnlineDPO with detailed configurations.
|
||||
|
||||
**GRPO deep dive**: See [references/grpo-training.md](references/grpo-training.md) for expert-level GRPO patterns — reward function design philosophy, training insights (why loss increases, mode collapse detection), hyperparameter tuning, multi-stage training, and troubleshooting. Production-ready template in [templates/basic_grpo_training.py](templates/basic_grpo_training.py).
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA required)
|
||||
- **VRAM**: Depends on model and method
|
||||
- SFT 7B: 16GB (with LoRA)
|
||||
- DPO 7B: 24GB (stores reference model)
|
||||
- PPO 7B: 40GB (policy + reward model)
|
||||
- GRPO 7B: 24GB (more memory efficient)
|
||||
- **Multi-GPU**: Supported via `accelerate`
|
||||
- **Mixed precision**: BF16 recommended (A100/H100)
|
||||
|
||||
**Memory optimization**:
|
||||
- Use LoRA/QLoRA for all methods
|
||||
- Enable gradient checkpointing
|
||||
- Use smaller batch sizes with gradient accumulation
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://huggingface.co/docs/trl/
|
||||
- GitHub: https://github.com/huggingface/trl
|
||||
- Papers:
|
||||
- "Training language models to follow instructions with human feedback" (InstructGPT, 2022)
|
||||
- "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" (DPO, 2023)
|
||||
- "Group Relative Policy Optimization" (GRPO, 2024)
|
||||
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
# DPO Variants
|
||||
|
||||
Complete guide to Direct Preference Optimization loss variants in TRL.
|
||||
|
||||
## Overview
|
||||
|
||||
DPO optimizes models using preference data (chosen/rejected pairs). TRL supports 10+ loss variants for different scenarios.
|
||||
|
||||
## Loss Types
|
||||
|
||||
### 1. Sigmoid (Standard DPO)
|
||||
|
||||
**Formula**: `-log(sigmoid(β * logits))`
|
||||
|
||||
**When to use**: Default choice, general preference alignment
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="sigmoid",
|
||||
beta=0.1, # KL penalty
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=1e-6
|
||||
)
|
||||
```
|
||||
|
||||
### 2. IPO (Identity Policy Optimization)
|
||||
|
||||
**Formula**: `(logits - 1/(2β))²`
|
||||
|
||||
**When to use**: Better theoretical foundation, reduce overfitting
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="ipo",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=90,
|
||||
learning_rate=1e-2
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Hinge (SLiC)
|
||||
|
||||
**Formula**: `ReLU(1 - β * logits)`
|
||||
|
||||
**When to use**: Margin-based objective
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="hinge",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=512,
|
||||
learning_rate=1e-4
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Robust DPO
|
||||
|
||||
**Formula**: Sigmoid with label smoothing for noise robustness
|
||||
|
||||
**When to use**: Noisy preference labels
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="robust",
|
||||
beta=0.01,
|
||||
label_smoothing=0.1, # Noise probability
|
||||
per_device_train_batch_size=16,
|
||||
learning_rate=1e-3,
|
||||
max_prompt_length=128,
|
||||
max_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 5. BCO Pair (Binary Classification)
|
||||
|
||||
**Formula**: Train binary classifier (chosen=1, rejected=0)
|
||||
|
||||
**When to use**: Pairwise preference data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="bco_pair",
|
||||
beta=0.01,
|
||||
per_device_train_batch_size=128,
|
||||
learning_rate=5e-7,
|
||||
max_prompt_length=1536,
|
||||
max_completion_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 6. SPPO Hard
|
||||
|
||||
**Formula**: Push chosen→0.5, rejected→-0.5
|
||||
|
||||
**When to use**: Nash equilibrium, sparse data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="sppo_hard",
|
||||
beta=0.1
|
||||
)
|
||||
```
|
||||
|
||||
### 7. DiscoPOP
|
||||
|
||||
**Formula**: Log-Ratio Modulated Loss
|
||||
|
||||
**When to use**: Automated loss discovery
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="discopop",
|
||||
beta=0.05,
|
||||
discopop_tau=0.05,
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=5e-7
|
||||
)
|
||||
```
|
||||
|
||||
### 8. APO Zero
|
||||
|
||||
**Formula**: Increase chosen, decrease rejected likelihood
|
||||
|
||||
**When to use**: Model worse than winning outputs
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="apo_zero",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=2e-7,
|
||||
max_prompt_length=512,
|
||||
max_completion_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 9. APO Down
|
||||
|
||||
**Formula**: Decrease both, emphasize rejected reduction
|
||||
|
||||
**When to use**: Model better than winning outputs
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="apo_down",
|
||||
beta=0.1,
|
||||
# Same hyperparameters as apo_zero
|
||||
)
|
||||
```
|
||||
|
||||
### 10. AOT & AOT Pair
|
||||
|
||||
**Formula**: Distributional alignment via stochastic dominance
|
||||
|
||||
**When to use**:
|
||||
- `aot_pair`: Paired preference data
|
||||
- `aot`: Unpaired data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="aot_pair", # or "aot"
|
||||
beta=0.1,
|
||||
label_smoothing=0.0
|
||||
)
|
||||
```
|
||||
|
||||
## Multi-Loss Training
|
||||
|
||||
Combine multiple losses:
|
||||
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type=["sigmoid", "ipo"],
|
||||
loss_weights=[0.7, 0.3], # Weighted combination
|
||||
beta=0.1
|
||||
)
|
||||
```
|
||||
|
||||
## Key Parameters
|
||||
|
||||
### Beta (β)
|
||||
|
||||
Controls deviation from reference model:
|
||||
- **Higher** (0.5): More conservative, stays close to reference
|
||||
- **Lower** (0.01): More aggressive alignment
|
||||
- **Default**: 0.1
|
||||
|
||||
### Label Smoothing
|
||||
|
||||
For robust DPO:
|
||||
- **0.0**: No smoothing (default)
|
||||
- **0.1-0.3**: Moderate noise robustness
|
||||
- **0.5**: Maximum noise tolerance
|
||||
|
||||
### Max Lengths
|
||||
|
||||
- `max_prompt_length`: 128-1536
|
||||
- `max_completion_length`: 128-512
|
||||
- `max_length`: Total sequence (1024-2048)
|
||||
|
||||
## Comparison Table
|
||||
|
||||
| Loss | Speed | Stability | Best For |
|
||||
|------|-------|-----------|----------|
|
||||
| Sigmoid | Fast | Good | **General use** |
|
||||
| IPO | Fast | Better | Overfitting issues |
|
||||
| Hinge | Fast | Good | Margin objectives |
|
||||
| Robust | Fast | Best | Noisy data |
|
||||
| BCO | Medium | Good | Binary classification |
|
||||
| DiscoPOP | Fast | Good | New architectures |
|
||||
| APO | Fast | Good | Model quality matching |
|
||||
|
||||
## References
|
||||
|
||||
- DPO paper: https://arxiv.org/abs/2305.18290
|
||||
- IPO paper: https://arxiv.org/abs/2310.12036
|
||||
- TRL docs: https://huggingface.co/docs/trl/dpo_trainer
|
||||
|
|
@ -0,0 +1,504 @@
|
|||
# GRPO (Group Relative Policy Optimization) — Deep Guide
|
||||
|
||||
Expert-level patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions using TRL's `GRPOTrainer`. This is the deep reference for the GRPO workflow summarized in the main skill.
|
||||
|
||||
## When to use GRPO
|
||||
|
||||
Use GRPO when you need to:
|
||||
- **Enforce specific output formats** (XML tags, JSON, structured reasoning)
|
||||
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
|
||||
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
|
||||
- **Align models to domain-specific behaviors** without labeled preference data
|
||||
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
|
||||
|
||||
**Do NOT use GRPO for:**
|
||||
- Simple supervised fine-tuning tasks → use SFT
|
||||
- Tasks without clear reward signals
|
||||
- When you already have high-quality preference pairs → use DPO/PPO
|
||||
|
||||
## Core concepts
|
||||
|
||||
### 1. GRPO algorithm fundamentals
|
||||
|
||||
**Key mechanism:**
|
||||
- Generates **multiple completions** per prompt (group size: 4–16)
|
||||
- Compares completions within each group using reward functions
|
||||
- Updates policy to favor higher-rewarded responses relative to the group
|
||||
|
||||
**Critical differences from PPO:**
|
||||
- No separate reward model needed
|
||||
- More sample-efficient (learns from within-group comparisons)
|
||||
- Simpler to implement and debug
|
||||
|
||||
**Mathematical intuition:**
|
||||
```
|
||||
For each prompt p:
|
||||
1. Generate N completions: {c₁, c₂, ..., cₙ}
|
||||
2. Compute rewards: {r₁, r₂, ..., rₙ}
|
||||
3. Learn to increase probability of high-reward completions
|
||||
relative to low-reward ones in the same group
|
||||
```
|
||||
|
||||
### 2. Reward function design philosophy
|
||||
|
||||
**Golden rules:**
|
||||
1. **Compose multiple reward functions** — each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** — higher weight = stronger signal
|
||||
3. **Use incremental rewards** — partial credit for partial compliance
|
||||
4. **Test rewards independently** — debug each reward function in isolation
|
||||
|
||||
**Reward function types:**
|
||||
|
||||
| Type | Use Case | Example Weight |
|
||||
|------|----------|----------------|
|
||||
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
|
||||
| **Format** | Strict structure enforcement | 0.5–1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1–0.5 |
|
||||
| **Style** | Penalize unwanted patterns | −0.5 to 0.5 |
|
||||
|
||||
## Implementation workflow
|
||||
|
||||
### Step 1: Dataset preparation
|
||||
|
||||
**Critical requirements:**
|
||||
- Prompts in chat format (list of dicts with `role` and `content`)
|
||||
- Include system prompts to set expectations
|
||||
- For verifiable tasks, include ground truth answers as additional columns
|
||||
|
||||
```python
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
def prepare_dataset(raw_data):
|
||||
"""Transform raw data into GRPO-compatible format.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content (system + user messages)
|
||||
- 'answer': str (ground truth, optional but recommended)
|
||||
"""
|
||||
return raw_data.map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': extract_answer(x['raw_answer'])
|
||||
})
|
||||
```
|
||||
|
||||
**Pro tips:**
|
||||
- Use one-shot or few-shot examples in the system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256–512 tokens)
|
||||
- Validate data quality before training (garbage in = garbage out)
|
||||
|
||||
### Step 2: Reward function implementation
|
||||
|
||||
**Template structure:**
|
||||
```python
|
||||
def reward_function_name(
|
||||
prompts, # List[List[Dict]]: Original prompts
|
||||
completions, # List[List[Dict]]: Model generations
|
||||
answer=None, # Optional: Ground truth from dataset
|
||||
**kwargs # Additional dataset columns
|
||||
) -> list[float]:
|
||||
"""Evaluate completions and return rewards (one per completion)."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
for response in responses:
|
||||
score = compute_score(response)
|
||||
rewards.append(score)
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Example 1: correctness reward (math/coding)**
|
||||
```python
|
||||
def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
"""Reward correct answers with high score."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_final_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0
|
||||
for ans, gt in zip(extracted, answer)]
|
||||
```
|
||||
|
||||
**Example 2: format reward (structured output)**
|
||||
```python
|
||||
import re
|
||||
|
||||
def format_reward(completions, **kwargs):
|
||||
"""Reward XML-like structured format."""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
|
||||
for r in responses]
|
||||
```
|
||||
|
||||
**Example 3: incremental format reward (partial credit)**
|
||||
```python
|
||||
def incremental_format_reward(completions, **kwargs):
|
||||
"""Award partial credit for format compliance."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r: score += 0.25
|
||||
if '</reasoning>' in r: score += 0.25
|
||||
if '<answer>' in r: score += 0.25
|
||||
if '</answer>' in r: score += 0.25
|
||||
# Penalize extra text after closing tag
|
||||
if r.count('</answer>') == 1:
|
||||
extra_text = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra_text) * 0.001
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Critical insight:** Combine 3–5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
|
||||
### Step 3: Training configuration
|
||||
|
||||
**Memory-optimized config (small GPU)**
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6, # Lower = more stable
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4, # Effective batch = 4
|
||||
|
||||
# GRPO-specific
|
||||
num_generations=8, # Group size: 8–16 recommended
|
||||
max_prompt_length=256,
|
||||
max_completion_length=512,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
max_steps=None,
|
||||
|
||||
# Optimization
|
||||
bf16=True, # Faster on A100/H100
|
||||
optim="adamw_8bit", # Memory-efficient optimizer
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb",
|
||||
)
|
||||
```
|
||||
|
||||
**High-performance config (large GPU)**
|
||||
```python
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
learning_rate=1e-5,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=2,
|
||||
num_generations=16, # Larger groups = better signal
|
||||
max_prompt_length=512,
|
||||
max_completion_length=1024,
|
||||
num_train_epochs=1,
|
||||
bf16=True,
|
||||
use_vllm=True, # Fast generation with vLLM
|
||||
logging_steps=10,
|
||||
)
|
||||
```
|
||||
|
||||
**Critical hyperparameters:**
|
||||
|
||||
| Parameter | Impact | Tuning Advice |
|
||||
|-----------|--------|---------------|
|
||||
| `num_generations` | Group size for comparison | Start 8, increase to 16 if GPU allows |
|
||||
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 reasoning, 256 short answers) |
|
||||
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
|
||||
|
||||
### Step 4: Model setup and training
|
||||
|
||||
**Standard setup (Transformers + TRL)**
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer
|
||||
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2", # 2–3× faster
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Optional: LoRA for parameter-efficient training
|
||||
peft_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward,
|
||||
format_reward,
|
||||
correctness_reward,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model("final_model")
|
||||
```
|
||||
|
||||
**Unsloth setup (2–3× faster)**
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="google/gemma-3-1b-it",
|
||||
max_seq_length=1024,
|
||||
load_in_4bit=True,
|
||||
fast_inference=True,
|
||||
max_lora_rank=32,
|
||||
)
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=32,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"],
|
||||
lora_alpha=32,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
# Rest is identical to the standard setup
|
||||
trainer = GRPOTrainer(model=model, ...)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Critical training insights
|
||||
|
||||
### 1. Loss behavior (EXPECTED pattern)
|
||||
- **Loss starts near 0 and INCREASES during training** — this is CORRECT
|
||||
- Loss measures KL divergence from initial policy; the model is learning (diverging from original behavior to optimize rewards)
|
||||
- **Monitor reward metrics, not loss, for progress**
|
||||
|
||||
### 2. Reward tracking
|
||||
|
||||
Key metrics to watch:
|
||||
- `reward` — average across all completions
|
||||
- `reward_std` — diversity within groups (should remain > 0)
|
||||
- `kl` — KL divergence from reference (should grow moderately)
|
||||
|
||||
**Healthy pattern:**
|
||||
```
|
||||
Step Reward Reward_Std KL
|
||||
100 0.5 0.3 0.02
|
||||
200 0.8 0.25 0.05
|
||||
300 1.2 0.2 0.08 ← Good progression
|
||||
400 1.5 0.15 0.12
|
||||
```
|
||||
|
||||
**Warning signs:**
|
||||
- `reward_std` → 0 (model collapsing to a single response)
|
||||
- `kl` exploding (> 0.5) — diverging too much, reduce LR
|
||||
- Reward stuck — reward functions too harsh or model capacity issue
|
||||
|
||||
### 3. Common pitfalls and solutions
|
||||
|
||||
| Problem | Symptom | Solution |
|
||||
|---------|---------|----------|
|
||||
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
|
||||
| **No learning** | Flat rewards | Check reward function logic, increase LR |
|
||||
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
|
||||
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
|
||||
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
|
||||
|
||||
## Advanced patterns
|
||||
|
||||
### 1. Multi-stage training
|
||||
|
||||
For complex tasks, train in stages:
|
||||
|
||||
```python
|
||||
# Stage 1: Format compliance
|
||||
trainer_stage1 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[incremental_format_reward, format_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage1.train()
|
||||
|
||||
# Stage 2: Correctness
|
||||
trainer_stage2 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[format_reward, correctness_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage2.train()
|
||||
```
|
||||
|
||||
### 2. Adaptive reward scaling
|
||||
|
||||
```python
|
||||
class AdaptiveReward:
|
||||
def __init__(self, base_reward_func, initial_weight=1.0):
|
||||
self.func = base_reward_func
|
||||
self.weight = initial_weight
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
rewards = self.func(*args, **kwargs)
|
||||
return [r * self.weight for r in rewards]
|
||||
|
||||
def adjust_weight(self, success_rate):
|
||||
"""Increase weight if model struggling, decrease if succeeding."""
|
||||
if success_rate < 0.3:
|
||||
self.weight *= 1.2
|
||||
elif success_rate > 0.8:
|
||||
self.weight *= 0.9
|
||||
```
|
||||
|
||||
### 3. Custom dataset integration
|
||||
|
||||
```python
|
||||
def load_custom_knowledge_base(csv_path):
|
||||
import pandas as pd
|
||||
df = pd.read_csv(csv_path)
|
||||
return Dataset.from_pandas(df).map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': x['expert_answer']
|
||||
})
|
||||
```
|
||||
|
||||
## Deployment and inference
|
||||
|
||||
### Save and merge LoRA
|
||||
```python
|
||||
if hasattr(trainer.model, 'merge_and_unload'):
|
||||
merged_model = trainer.model.merge_and_unload()
|
||||
merged_model.save_pretrained("production_model")
|
||||
tokenizer.save_pretrained("production_model")
|
||||
```
|
||||
|
||||
### Inference
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline("text-generation", model="production_model", tokenizer=tokenizer)
|
||||
|
||||
result = generator(
|
||||
[
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': "What is 15 + 27?"},
|
||||
],
|
||||
max_new_tokens=256,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
)
|
||||
print(result[0]['generated_text'])
|
||||
```
|
||||
|
||||
## Best practices checklist
|
||||
|
||||
**Before training:**
|
||||
- [ ] Validate dataset format (prompts as List[Dict])
|
||||
- [ ] Test reward functions on sample data
|
||||
- [ ] Calculate expected `max_prompt_length` from data
|
||||
- [ ] Choose `num_generations` based on GPU memory
|
||||
- [ ] Set up logging (wandb recommended)
|
||||
|
||||
**During training:**
|
||||
- [ ] Monitor reward progression (should increase)
|
||||
- [ ] Check `reward_std` (should stay > 0.1)
|
||||
- [ ] Watch for OOM errors (reduce batch size if needed)
|
||||
- [ ] Sample generations every 50–100 steps
|
||||
- [ ] Validate format compliance on holdout set
|
||||
|
||||
**After training:**
|
||||
- [ ] Merge LoRA weights if using PEFT
|
||||
- [ ] Test on diverse prompts
|
||||
- [ ] Compare to baseline model
|
||||
- [ ] Document reward weights and hyperparameters
|
||||
- [ ] Save reproducibility config
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Debugging workflow
|
||||
1. **Isolate reward functions** — test each independently
|
||||
2. **Check data distribution** — ensure diversity in prompts
|
||||
3. **Reduce complexity** — start with single reward, add gradually
|
||||
4. **Monitor generations** — print samples every N steps
|
||||
5. **Validate extraction logic** — ensure answer parsing works
|
||||
|
||||
### Quick debug reward
|
||||
```python
|
||||
def debug_reward(completions, **kwargs):
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
for i, r in enumerate(responses[:2]):
|
||||
print(f"Response {i}: {r[:200]}...")
|
||||
return [1.0] * len(responses)
|
||||
|
||||
# Test without training
|
||||
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
|
||||
trainer.generate_completions(dataset[:1])
|
||||
```
|
||||
|
||||
## Template
|
||||
|
||||
A production-ready training script lives at **`../templates/basic_grpo_training.py`**. It uses Qwen 2.5-1.5B-Instruct with LoRA and three reward functions (incremental format, strict format, correctness) on GSM8K. Copy and adapt:
|
||||
1. `get_dataset()` — swap in your data loader
|
||||
2. Reward functions — tune to your task
|
||||
3. `SYSTEM_PROMPT` — match your output format
|
||||
4. `GRPOConfig` — adjust hyperparameters for your GPU
|
||||
|
||||
## References and resources
|
||||
|
||||
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
|
||||
- GRPO paper (DeepSeek): https://arxiv.org/abs/2402.03300
|
||||
- DeepSeek R1 paper: https://arxiv.org/abs/2501.12948
|
||||
- Open R1 implementation: https://github.com/huggingface/open-r1
|
||||
- TRL examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
- Unsloth (faster training): https://docs.unsloth.ai/
|
||||
|
||||
## Critical reminders
|
||||
|
||||
- **Loss goes UP during training** — this is normal (it's KL divergence)
|
||||
- **Use 3–5 reward functions** — single rewards often fail
|
||||
- **Test rewards before training** — debug each function independently
|
||||
- **Monitor `reward_std`** — should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with `num_generations=4–8`** — scale up if GPU allows
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
# Online RL Methods
|
||||
|
||||
Guide to online reinforcement learning with PPO, GRPO, RLOO, and OnlineDPO.
|
||||
|
||||
## Overview
|
||||
|
||||
Online RL generates completions during training and optimizes based on rewards.
|
||||
|
||||
## PPO (Proximal Policy Optimization)
|
||||
|
||||
Classic RL algorithm for LLM alignment.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
python -m trl.scripts.ppo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--reward_model_path reward-model \
|
||||
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
||||
--output_dir model-ppo \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--total_episodes 10000 \
|
||||
--num_ppo_epochs 4 \
|
||||
--kl_coef 0.05
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `kl_coef`: KL penalty (0.05-0.2)
|
||||
- `num_ppo_epochs`: Epochs per batch (2-4)
|
||||
- `cliprange`: PPO clip (0.1-0.3)
|
||||
- `vf_coef`: Value function coef (0.1)
|
||||
|
||||
## GRPO (Group Relative Policy Optimization)
|
||||
|
||||
Memory-efficient online RL.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Define reward function
|
||||
def reward_func(completions, **kwargs):
|
||||
return [len(set(c.split())) for c in completions]
|
||||
|
||||
config = GRPOConfig(
|
||||
output_dir="model-grpo",
|
||||
num_generations=4, # Completions per prompt
|
||||
max_new_tokens=128
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_func,
|
||||
args=config,
|
||||
train_dataset=load_dataset("trl-lib/tldr", split="train")
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `num_generations`: 2-8 completions
|
||||
- `max_new_tokens`: 64-256
|
||||
- Learning rate: 1e-5 to 1e-4
|
||||
|
||||
## Memory Comparison
|
||||
|
||||
| Method | Memory (7B) | Speed | Use Case |
|
||||
|--------|-------------|-------|----------|
|
||||
| PPO | 40GB | Medium | Maximum control |
|
||||
| GRPO | 24GB | Fast | **Memory-constrained** |
|
||||
| OnlineDPO | 28GB | Fast | No reward model |
|
||||
|
||||
## References
|
||||
|
||||
- PPO paper: https://arxiv.org/abs/1707.06347
|
||||
- GRPO paper: https://arxiv.org/abs/2402.03300
|
||||
- TRL docs: https://huggingface.co/docs/trl/
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
# Reward Modeling
|
||||
|
||||
Guide to training reward models with TRL for RLHF pipelines.
|
||||
|
||||
## Overview
|
||||
|
||||
Reward models score completions based on human preferences. Used in:
|
||||
- PPO training (RL feedback)
|
||||
- GRPO online RL
|
||||
- Completion ranking
|
||||
|
||||
## Basic Training
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model (num_labels=1 for single reward score)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
num_labels=1
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
|
||||
# Load preference dataset (chosen/rejected pairs)
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
# Configure
|
||||
config = RewardConfig(
|
||||
output_dir="Qwen2.5-Reward",
|
||||
per_device_train_batch_size=2,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
Required fields:
|
||||
```json
|
||||
{
|
||||
"prompt": "Question or instruction",
|
||||
"chosen": "Better response",
|
||||
"rejected": "Worse response"
|
||||
}
|
||||
```
|
||||
|
||||
## Bradley-Terry Loss
|
||||
|
||||
Default loss function:
|
||||
```
|
||||
loss = -log(sigmoid(reward_chosen - reward_rejected))
|
||||
```
|
||||
|
||||
Learns to score chosen > rejected.
|
||||
|
||||
## Using Reward Models
|
||||
|
||||
### Inference
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Load trained reward model
|
||||
reward_pipe = pipeline("text-classification", model="Qwen2.5-Reward")
|
||||
|
||||
# Score completions
|
||||
texts = ["Good answer", "Bad answer"]
|
||||
scores = reward_pipe(texts)
|
||||
print(scores) # Higher score = better
|
||||
```
|
||||
|
||||
### In PPO
|
||||
|
||||
```python
|
||||
from trl import PPOTrainer, PPOConfig
|
||||
|
||||
config = PPOConfig(
|
||||
reward_model_path="Qwen2.5-Reward" # Use trained reward model
|
||||
)
|
||||
|
||||
trainer = PPOTrainer(
|
||||
model=policy_model,
|
||||
config=config,
|
||||
# Reward model loaded automatically
|
||||
)
|
||||
```
|
||||
|
||||
## Hyperparameters
|
||||
|
||||
| Model Size | Learning Rate | Batch Size | Epochs |
|
||||
|------------|---------------|------------|--------|
|
||||
| <1B | 2e-5 | 4-8 | 1-2 |
|
||||
| 1-7B | 1e-5 | 2-4 | 1 |
|
||||
| 7-13B | 5e-6 | 1-2 | 1 |
|
||||
|
||||
## Evaluation
|
||||
|
||||
Check reward separation:
|
||||
```python
|
||||
# Chosen should score higher than rejected
|
||||
chosen_rewards = model(**chosen_inputs).logits
|
||||
rejected_rewards = model(**rejected_inputs).logits
|
||||
|
||||
accuracy = (chosen_rewards > rejected_rewards).float().mean()
|
||||
print(f"Accuracy: {accuracy:.2%}") # Target: >80%
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- InstructGPT paper: https://arxiv.org/abs/2203.02155
|
||||
- TRL docs: https://huggingface.co/docs/trl/reward_trainer
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
# SFT Training Guide
|
||||
|
||||
Complete guide to Supervised Fine-Tuning (SFT) with TRL for instruction tuning and task-specific fine-tuning.
|
||||
|
||||
## Overview
|
||||
|
||||
SFT trains models on input-output pairs to minimize cross-entropy loss. Use for:
|
||||
- Instruction following
|
||||
- Task-specific fine-tuning
|
||||
- Chatbot training
|
||||
- Domain adaptation
|
||||
|
||||
## Dataset Formats
|
||||
|
||||
### Format 1: Prompt-Completion
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"completion": "The capital of France is Paris."
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Format 2: Conversational (ChatML)
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
{"role": "assistant", "content": "Python is a programming language."}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Format 3: Text-only
|
||||
|
||||
```json
|
||||
[
|
||||
{"text": "User: Hello\nAssistant: Hi! How can I help?"}
|
||||
]
|
||||
```
|
||||
|
||||
## Basic Training
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
# Configure
|
||||
config = SFTConfig(
|
||||
output_dir="Qwen2.5-SFT",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=2e-5,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Chat Templates
|
||||
|
||||
Apply chat templates automatically:
|
||||
|
||||
```python
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset, # Messages format
|
||||
tokenizer=tokenizer
|
||||
# Chat template applied automatically
|
||||
)
|
||||
```
|
||||
|
||||
Or manually:
|
||||
```python
|
||||
def format_chat(example):
|
||||
messages = example["messages"]
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
return {"text": text}
|
||||
|
||||
dataset = dataset.map(format_chat)
|
||||
```
|
||||
|
||||
## Packing for Efficiency
|
||||
|
||||
Pack multiple sequences into one to maximize GPU utilization:
|
||||
|
||||
```python
|
||||
config = SFTConfig(
|
||||
packing=True, # Enable packing
|
||||
max_seq_length=2048,
|
||||
dataset_text_field="text"
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**: 2-3× faster training
|
||||
**Trade-off**: Slightly more complex batching
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
```bash
|
||||
accelerate launch --num_processes 4 train_sft.py
|
||||
```
|
||||
|
||||
Or with config:
|
||||
```python
|
||||
config = SFTConfig(
|
||||
output_dir="model-sft",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=1
|
||||
)
|
||||
```
|
||||
|
||||
## LoRA Fine-Tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
lora_dropout=0.05,
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
peft_config=lora_config # Add LoRA
|
||||
)
|
||||
```
|
||||
|
||||
## Hyperparameters
|
||||
|
||||
| Model Size | Learning Rate | Batch Size | Epochs |
|
||||
|------------|---------------|------------|--------|
|
||||
| <1B | 5e-5 | 8-16 | 1-3 |
|
||||
| 1-7B | 2e-5 | 4-8 | 1-2 |
|
||||
| 7-13B | 1e-5 | 2-4 | 1 |
|
||||
| 13B+ | 5e-6 | 1-2 | 1 |
|
||||
|
||||
## References
|
||||
|
||||
- TRL docs: https://huggingface.co/docs/trl/sft_trainer
|
||||
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
"""
|
||||
Basic GRPO Training Template
|
||||
=============================
|
||||
|
||||
A minimal, production-ready template for GRPO training with TRL.
|
||||
Adapt this for your specific task by modifying:
|
||||
1. Dataset loading (get_dataset function)
|
||||
2. Reward functions (reward_*_func)
|
||||
3. System prompt (SYSTEM_PROMPT)
|
||||
4. Hyperparameters (GRPOConfig)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import re
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
# ==================== CONFIGURATION ====================
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
OUTPUT_DIR = "outputs/grpo-model"
|
||||
MAX_PROMPT_LENGTH = 256
|
||||
MAX_COMPLETION_LENGTH = 512
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
# ==================== DATASET ====================
|
||||
|
||||
def get_dataset(split="train"):
|
||||
"""
|
||||
Load and prepare your dataset.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content
|
||||
- 'answer': str (ground truth, optional)
|
||||
"""
|
||||
# Example: GSM8K math dataset
|
||||
data = load_dataset('openai/gsm8k', 'main')[split]
|
||||
|
||||
def process_example(x):
|
||||
# Extract ground truth answer
|
||||
answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None
|
||||
|
||||
return {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': answer
|
||||
}
|
||||
|
||||
return data.map(process_example)
|
||||
|
||||
# ==================== HELPER FUNCTIONS ====================
|
||||
|
||||
def extract_xml_tag(text: str, tag: str) -> str:
|
||||
"""Extract content between XML tags."""
|
||||
pattern = f'<{tag}>(.*?)</{tag}>'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
return match.group(1).strip() if match else ""
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
"""Extract the final answer from structured output."""
|
||||
return extract_xml_tag(text, 'answer')
|
||||
|
||||
# ==================== REWARD FUNCTIONS ====================
|
||||
|
||||
def correctness_reward_func(prompts, completions, answer, **kwargs):
|
||||
"""
|
||||
Reward correct answers.
|
||||
Weight: 2.0 (highest priority)
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Reward proper XML format.
|
||||
Weight: 0.5
|
||||
"""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]
|
||||
|
||||
def incremental_format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Incremental reward for partial format compliance.
|
||||
Weight: up to 0.5
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.125
|
||||
if '</reasoning>' in r:
|
||||
score += 0.125
|
||||
if '<answer>' in r:
|
||||
score += 0.125
|
||||
if '</answer>' in r:
|
||||
score += 0.125
|
||||
|
||||
# Penalize extra content after closing tag
|
||||
if '</answer>' in r:
|
||||
extra = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra) * 0.001
|
||||
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== MODEL SETUP ====================
|
||||
|
||||
def setup_model_and_tokenizer():
|
||||
"""Load model and tokenizer with optimizations."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_NAME,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def get_peft_config():
|
||||
"""LoRA configuration for parameter-efficient training."""
|
||||
return LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# ==================== TRAINING ====================
|
||||
|
||||
def main():
|
||||
"""Main training function."""
|
||||
|
||||
# Load data
|
||||
print("Loading dataset...")
|
||||
dataset = get_dataset()
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Setup model
|
||||
print("Loading model...")
|
||||
model, tokenizer = setup_model_and_tokenizer()
|
||||
|
||||
# Training configuration
|
||||
training_args = GRPOConfig(
|
||||
output_dir=OUTPUT_DIR,
|
||||
run_name="grpo-training",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
|
||||
# GRPO specific
|
||||
num_generations=8,
|
||||
max_prompt_length=MAX_PROMPT_LENGTH,
|
||||
max_completion_length=MAX_COMPLETION_LENGTH,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
|
||||
# Optimization
|
||||
bf16=True,
|
||||
optim="adamw_8bit",
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Change to "none" to disable logging
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward_func,
|
||||
format_reward_func,
|
||||
correctness_reward_func,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=get_peft_config(),
|
||||
)
|
||||
|
||||
# Train
|
||||
print("Starting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print(f"Saving model to {OUTPUT_DIR}/final")
|
||||
trainer.save_model(f"{OUTPUT_DIR}/final")
|
||||
|
||||
print("Training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
84
optional-skills/mlops/training/unsloth/SKILL.md
Normal file
84
optional-skills/mlops/training/unsloth/SKILL.md
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
---
|
||||
name: unsloth
|
||||
description: "Unsloth: 2-5x faster LoRA/QLoRA fine-tuning, less VRAM."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [unsloth, torch, transformers, trl, datasets, peft]
|
||||
platforms: [linux, macos]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Fine-Tuning, Unsloth, Fast Training, LoRA, QLoRA, Memory-Efficient, Optimization, Llama, Mistral, Gemma, Qwen]
|
||||
|
||||
---
|
||||
|
||||
# Unsloth Skill
|
||||
|
||||
Comprehensive assistance with unsloth development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with unsloth
|
||||
- Asking about unsloth features or APIs
|
||||
- Implementing unsloth solutions
|
||||
- Debugging unsloth code
|
||||
- Learning unsloth best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
*Quick reference patterns will be added as you use the skill.*
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **llms-txt.md** - Llms-Txt documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
<!-- Trigger re-upload 1763621536 -->
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# Unsloth Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Llms-Txt
|
||||
**File:** `llms-txt.md`
|
||||
**Pages:** 136
|
||||
16799
optional-skills/mlops/training/unsloth/references/llms-full.md
Normal file
16799
optional-skills/mlops/training/unsloth/references/llms-full.md
Normal file
File diff suppressed because one or more lines are too long
12044
optional-skills/mlops/training/unsloth/references/llms-txt.md
Normal file
12044
optional-skills/mlops/training/unsloth/references/llms-txt.md
Normal file
File diff suppressed because one or more lines are too long
82
optional-skills/mlops/training/unsloth/references/llms.md
Normal file
82
optional-skills/mlops/training/unsloth/references/llms.md
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
# Unsloth Documentation
|
||||
|
||||
## Unsloth Documentation
|
||||
|
||||
- [Unsloth Docs](/get-started/unsloth-docs.md): Train your own model with Unsloth, an open-source framework for LLM fine-tuning and reinforcement learning.
|
||||
- [Beginner? Start here!](/get-started/beginner-start-here.md)
|
||||
- [Unsloth Requirements](/get-started/beginner-start-here/unsloth-requirements.md): Here are Unsloth's requirements including system and GPU VRAM requirements.
|
||||
- [FAQ + Is Fine-tuning Right For Me?](/get-started/beginner-start-here/faq-+-is-fine-tuning-right-for-me.md): If you're stuck on if fine-tuning is right for you, see here! Learn about fine-tuning misconceptions, how it compared to RAG and more:
|
||||
- [Unsloth Notebooks](/get-started/unsloth-notebooks.md): Explore our catalog of Unsloth notebooks:
|
||||
- [All Our Models](/get-started/all-our-models.md)
|
||||
- [Install & Update](/get-started/install-and-update.md): Learn to install Unsloth locally or online.
|
||||
- [Updating](/get-started/install-and-update/updating.md): To update or use an old version of Unsloth, follow the steps below:
|
||||
- [Pip Install](/get-started/install-and-update/pip-install.md): To install Unsloth locally via Pip, follow the steps below:
|
||||
- [Docker](/get-started/install-and-update/docker.md): Install Unsloth using our official Docker container
|
||||
- [Windows Installation](/get-started/install-and-update/windows-installation.md): See how to install Unsloth on Windows with or without WSL.
|
||||
- [AMD](/get-started/install-and-update/amd.md): Fine-tune with Unsloth on AMD GPUs.
|
||||
- [Conda Install](/get-started/install-and-update/conda-install.md): To install Unsloth locally on Conda, follow the steps below:
|
||||
- [Google Colab](/get-started/install-and-update/google-colab.md): To install and run Unsloth on Google Colab, follow the steps below:
|
||||
- [Fine-tuning LLMs Guide](/get-started/fine-tuning-llms-guide.md): Learn all the basics and best practices of fine-tuning. Beginner-friendly.
|
||||
- [What Model Should I Use?](/get-started/fine-tuning-llms-guide/what-model-should-i-use.md)
|
||||
- [Datasets Guide](/get-started/fine-tuning-llms-guide/datasets-guide.md): Learn how to create & prepare a dataset for fine-tuning.
|
||||
- [LoRA Hyperparameters Guide](/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide.md): Optimal lora rank. alpha, number of epochs, batch size & gradient accumulation, QLoRA vs LoRA, target modules and more!
|
||||
- [Tutorial: How to Finetune Llama-3 and Use In Ollama](/get-started/fine-tuning-llms-guide/tutorial-how-to-finetune-llama-3-and-use-in-ollama.md): Beginner's Guide for creating a customized personal assistant (like ChatGPT) to run locally on Ollama
|
||||
- [Reinforcement Learning (RL) Guide](/get-started/reinforcement-learning-rl-guide.md): Learn all about Reinforcement Learning (RL) and how to train your own DeepSeek-R1 reasoning model with Unsloth using GRPO. A complete guide from beginner to advanced.
|
||||
- [Tutorial: Train your own Reasoning model with GRPO](/get-started/reinforcement-learning-rl-guide/tutorial-train-your-own-reasoning-model-with-grpo.md): Beginner's Guide to transforming a model like Llama 3.1 (8B) into a reasoning model by using Unsloth and GRPO.
|
||||
- [Advanced RL Documentation](/get-started/reinforcement-learning-rl-guide/advanced-rl-documentation.md): Advanced documentation settings when using Unsloth with GRPO.
|
||||
- [Memory Efficient RL](/get-started/reinforcement-learning-rl-guide/memory-efficient-rl.md)
|
||||
- [RL Reward Hacking](/get-started/reinforcement-learning-rl-guide/rl-reward-hacking.md): Learn what is Reward Hacking in Reinforcement Learning and how to counter it.
|
||||
- [GSPO Reinforcement Learning](/get-started/reinforcement-learning-rl-guide/gspo-reinforcement-learning.md): Train with GSPO (Group Sequence Policy Optimization) RL in Unsloth.
|
||||
- [Reinforcement Learning - DPO, ORPO & KTO](/get-started/reinforcement-learning-rl-guide/reinforcement-learning-dpo-orpo-and-kto.md): To use the reward modelling functions for DPO, GRPO, ORPO or KTO with Unsloth, follow the steps below:
|
||||
- [DeepSeek-OCR: How to Run & Fine-tune](/new/deepseek-ocr-how-to-run-and-fine-tune.md): Guide on how to run and fine-tune DeepSeek-OCR locally.
|
||||
- [How to Fine-tune LLMs with Unsloth & Docker](/new/how-to-fine-tune-llms-with-unsloth-and-docker.md): Learn how to fine-tune LLMs or do Reinforcement Learning (RL) with Unsloth's Docker image.
|
||||
- [Vision Reinforcement Learning (VLM RL)](/new/vision-reinforcement-learning-vlm-rl.md): Train Vision/multimodal models via GRPO and RL with Unsloth!
|
||||
- [gpt-oss Reinforcement Learning](/new/gpt-oss-reinforcement-learning.md)
|
||||
- [Tutorial: How to Train gpt-oss with RL](/new/gpt-oss-reinforcement-learning/tutorial-how-to-train-gpt-oss-with-rl.md): Learn to train OpenAI gpt-oss with GRPO to autonomously beat 2048 locally or on Colab.
|
||||
- [Unsloth Dynamic GGUFs on Aider Polyglot](/new/unsloth-dynamic-ggufs-on-aider-polyglot.md): Performance of Unsloth Dynamic GGUFs on Aider Polyglot Benchmarks
|
||||
- [Qwen3-VL: How to Run & Fine-tune](/models/qwen3-vl-how-to-run-and-fine-tune.md): Learn to fine-tune and run Qwen3-VL locally with Unsloth.
|
||||
- [gpt-oss: How to Run & Fine-tune](/models/gpt-oss-how-to-run-and-fine-tune.md): Run & fine-tune OpenAI's new open-source models!
|
||||
- [Tutorial: How to Fine-tune gpt-oss](/models/gpt-oss-how-to-run-and-fine-tune/tutorial-how-to-fine-tune-gpt-oss.md): Learn step-by-step how to train OpenAI gpt-oss locally with Unsloth.
|
||||
- [Long Context gpt-oss Training](/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training.md)
|
||||
- [GLM-4.6: How to Run Locally](/models/glm-4.6-how-to-run-locally.md): A guide on how to run Z.ai's new GLM-4.6 model on your own local device!
|
||||
- [IBM Granite 4.0](/models/ibm-granite-4.0.md): How to run IBM Granite-4.0 with Unsloth GGUFs on llama.cpp, Ollama and how to fine-tune!
|
||||
- [DeepSeek-V3.1: How to Run Locally](/models/deepseek-v3.1-how-to-run-locally.md): A guide on how to run DeepSeek-V3.1 and Terminus on your own local device!
|
||||
- [Qwen3-Coder: How to Run Locally](/models/qwen3-coder-how-to-run-locally.md): Run Qwen3-Coder-30B-A3B-Instruct and 480B-A35B locally with Unsloth Dynamic quants.
|
||||
- [Gemma 3: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune.md): How to run Gemma 3 effectively with our GGUFs on llama.cpp, Ollama, Open WebUI and how to fine-tune with Unsloth!
|
||||
- [Gemma 3n: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune/gemma-3n-how-to-run-and-fine-tune.md): Run Google's new Gemma 3n locally with Dynamic GGUFs on llama.cpp, Ollama, Open WebUI and fine-tune with Unsloth!
|
||||
- [Qwen3: How to Run & Fine-tune](/models/qwen3-how-to-run-and-fine-tune.md): Learn to run & fine-tune Qwen3 locally with Unsloth + our Dynamic 2.0 quants
|
||||
- [Qwen3-2507](/models/qwen3-how-to-run-and-fine-tune/qwen3-2507.md): Run Qwen3-30B-A3B-2507 and 235B-A22B Thinking and Instruct versions locally on your device!
|
||||
- [Tutorials: How To Fine-tune & Run LLMs](/models/tutorials-how-to-fine-tune-and-run-llms.md): Learn how to run and fine-tune models for optimal performance 100% locally with Unsloth.
|
||||
- [DeepSeek-R1-0528: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-0528-how-to-run-locally.md): A guide on how to run DeepSeek-R1-0528 including Qwen3 on your own local device!
|
||||
- [Magistral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/magistral-how-to-run-and-fine-tune.md): Meet Magistral - Mistral's new reasoning models.
|
||||
- [Llama 4: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/llama-4-how-to-run-and-fine-tune.md): How to run Llama 4 locally using our dynamic GGUFs which recovers accuracy compared to standard quantization.
|
||||
- [Kimi K2: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/kimi-k2-how-to-run-locally.md): Guide on running Kimi K2 and Kimi-K2-Instruct-0905 on your own local device!
|
||||
- [Grok 2](/models/tutorials-how-to-fine-tune-and-run-llms/grok-2.md): Run xAI's Grok 2 model locally!
|
||||
- [Devstral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/devstral-how-to-run-and-fine-tune.md): Run and fine-tune Mistral Devstral 1.1, including Small-2507 and 2505.
|
||||
- [DeepSeek-V3-0324: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-v3-0324-how-to-run-locally.md): How to run DeepSeek-V3-0324 locally using our dynamic quants which recovers accuracy
|
||||
- [DeepSeek-R1: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally.md): A guide on how you can run our 1.58-bit Dynamic Quants for DeepSeek-R1 using llama.cpp.
|
||||
- [DeepSeek-R1 Dynamic 1.58-bit](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally/deepseek-r1-dynamic-1.58-bit.md): See performance comparison tables for Unsloth's Dynamic GGUF Quants vs Standard IMatrix Quants.
|
||||
- [QwQ-32B: How to Run effectively](/models/tutorials-how-to-fine-tune-and-run-llms/qwq-32b-how-to-run-effectively.md): How to run QwQ-32B effectively with our bug fixes and without endless generations + GGUFs.
|
||||
- [Phi-4 Reasoning: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/phi-4-reasoning-how-to-run-and-fine-tune.md): Learn to run & fine-tune Phi-4 reasoning models locally with Unsloth + our Dynamic 2.0 quants
|
||||
- [Running & Saving Models](/basics/running-and-saving-models.md): Learn how to save your finetuned model so you can run it in your favorite inference engine.
|
||||
- [Saving to GGUF](/basics/running-and-saving-models/saving-to-gguf.md): Saving models to 16bit for GGUF so you can use it for Ollama, Jan AI, Open WebUI and more!
|
||||
- [Saving to Ollama](/basics/running-and-saving-models/saving-to-ollama.md)
|
||||
- [Saving to vLLM for deployment](/basics/running-and-saving-models/saving-to-vllm-for-deployment.md): Saving models to 16bit for vLLM deployment and serving
|
||||
- [Saving to SGLang for deployment](/basics/running-and-saving-models/saving-to-sglang-for-deployment.md): Saving models to 16bit for SGLang for deployment and serving
|
||||
- [Unsloth Inference](/basics/running-and-saving-models/unsloth-inference.md): Learn how to run your finetuned model with Unsloth's faster inference.
|
||||
- [Troubleshooting Inference](/basics/running-and-saving-models/troubleshooting-inference.md): If you're experiencing issues when running or saving your model.
|
||||
- [vLLM Engine Arguments](/basics/running-and-saving-models/vllm-engine-arguments.md)
|
||||
- [LoRA Hot Swapping Guide](/basics/running-and-saving-models/lora-hot-swapping-guide.md)
|
||||
- [Text-to-Speech (TTS) Fine-tuning](/basics/text-to-speech-tts-fine-tuning.md): Learn how to fine-tune TTS & STT voice models with Unsloth.
|
||||
- [Unsloth Dynamic 2.0 GGUFs](/basics/unsloth-dynamic-2.0-ggufs.md): A big new upgrade to our Dynamic Quants!
|
||||
- [Vision Fine-tuning](/basics/vision-fine-tuning.md): Learn how to fine-tune vision/multimodal LLMs with Unsloth
|
||||
- [Fine-tuning LLMs with NVIDIA DGX Spark and Unsloth](/basics/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth.md): Tutorial on how to fine-tune and do reinforcement learning (RL) with OpenAI gpt-oss on NVIDIA DGX Spark.
|
||||
- [Fine-tuning LLMs with Blackwell, RTX 50 series & Unsloth](/basics/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth.md): Learn how to fine-tune LLMs on NVIDIA's Blackwell RTX 50 series and B200 GPUs with our step-by-step guide.
|
||||
- [Multi-GPU Training with Unsloth](/basics/multi-gpu-training-with-unsloth.md): Learn how to fine-tune LLMs on multiple GPUs and parallelism with Unsloth.
|
||||
- [Finetuning from Last Checkpoint](/basics/finetuning-from-last-checkpoint.md): Checkpointing allows you to save your finetuning progress so you can pause it and then continue.
|
||||
- [Troubleshooting & FAQs](/basics/troubleshooting-and-faqs.md): Tips to solve issues, and frequently asked questions.
|
||||
- [Chat Templates](/basics/chat-templates.md): Learn the fundamentals and customization options of chat templates, including Conversational, ChatML, ShareGPT, Alpaca formats, and more!
|
||||
- [Quantization-Aware Training (QAT)](/basics/quantization-aware-training-qat.md): Quantize models to 4-bit with Unsloth and PyTorch to recover accuracy.
|
||||
- [Unsloth Environment Flags](/basics/unsloth-environment-flags.md): Advanced flags which might be useful if you see breaking finetunes, or you want to turn stuff off.
|
||||
- [Continued Pretraining](/basics/continued-pretraining.md): AKA as Continued Finetuning. Unsloth allows you to continually pretrain so a model can learn a new language.
|
||||
- [Unsloth Benchmarks](/basics/unsloth-benchmarks.md): Unsloth recorded benchmarks on NVIDIA GPUs.
|
||||
Loading…
Add table
Add a link
Reference in a new issue