mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 01:31:41 +00:00
feat(gateway): skill-aware slash commands, paginated /commands, Telegram 100-cap (#3934)
* feat(gateway): skill-aware slash commands, paginated /commands, Telegram 100-cap

  Map active skills to Telegram's slash command menu so users can discover and invoke skills directly. Three changes:

  1. Telegram menu now includes active skill commands alongside built-in commands, capped at 100 entries (Telegram Bot API limit). Overflow commands remain callable but hidden from the picker. Logged at startup when the cap is hit.
  2. New /commands [page] gateway command for paginated browsing of all commands + skills. /help now shows the first 10 skill commands and points to /commands for the full list.
  3. When a user types a slash command that matches a disabled or uninstalled skill, they get actionable guidance:
     - Disabled: 'Enable it with: hermes skills config'
     - Optional (not installed): 'Install with: hermes skills install official/<path>'

  Built on ideas from PR #3921 by @kshitijk4poor.

* chore: move 21 niche skills to optional-skills

  Move specialized/niche skills from built-in (skills/) to optional (optional-skills/) to reduce the default skill count. Users can install them with: hermes skills install official/<category>/<name>

  Moved skills (21):
  - mlops: accelerate, chroma, faiss, flash-attention, hermes-atropos-environments, huggingface-tokenizers, instructor, lambda-labs, llava, nemo-curator, pinecone, pytorch-lightning, qdrant, saelens, simpo, slime, tensorrt-llm, torchtitan
  - research: domain-intel, duckduckgo-search
  - devops: inference-sh cli

  Built-in skills: 96 → 75
  Optional skills: 22 → 43

* fix: only include repo built-in skills in Telegram menu, not user-installed

  User-installed skills (from hub or manually added) stay accessible via /skills and by typing the command directly, but don't get registered in the Telegram slash command picker. Only skills whose SKILL.md is under the repo's skills/ directory are included in the menu.
This keeps the Telegram menu focused on the curated built-in set while user-installed skills remain discoverable through /skills and /commands.
This commit is contained in: parent 97d6813f51, commit 5ceed021dc. 73 changed files with 163 additions and 4 deletions.
---
name: instructor
description: Extract structured data from LLM responses with Pydantic validation, retry failed extractions automatically, parse complex JSON with type safety, and stream partial results with Instructor - battle-tested structured output library
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [instructor, pydantic, openai, anthropic]
metadata:
  hermes:
    tags: [Prompt Engineering, Instructor, Structured Output, Pydantic, Data Extraction, JSON Parsing, Type Safety, Validation, Streaming, OpenAI, Anthropic]
---

# Instructor: Structured LLM Outputs

## When to Use This Skill

Use Instructor when you need to:

- **Extract structured data** from LLM responses reliably
- **Validate outputs** against Pydantic schemas automatically
- **Retry failed extractions** with automatic error handling
- **Parse complex JSON** with type safety and validation
- **Stream partial results** for real-time processing
- **Support multiple LLM providers** with a consistent API

**GitHub Stars**: 15,000+ | **Battle-tested**: 100,000+ developers

## Installation

```bash
# Base installation
pip install instructor

# With specific providers
pip install "instructor[anthropic]"  # Anthropic Claude
pip install "instructor[openai]"     # OpenAI
pip install "instructor[all]"        # All providers
```

## Quick Start

### Basic Example: Extract User Data

```python
import instructor
from pydantic import BaseModel
from anthropic import Anthropic

# Define the output structure
class User(BaseModel):
    name: str
    age: int
    email: str

# Create an Instructor client
client = instructor.from_anthropic(Anthropic())

# Extract structured data
user = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "John Doe is 30 years old. His email is john@example.com"
    }],
    response_model=User
)

print(user.name)   # "John Doe"
print(user.age)    # 30
print(user.email)  # "john@example.com"
```

### With OpenAI

```python
from openai import OpenAI

client = instructor.from_openai(OpenAI())

user = client.chat.completions.create(
    model="gpt-4o-mini",
    response_model=User,
    messages=[{"role": "user", "content": "Extract: Alice, 25, alice@email.com"}]
)
```

## Core Concepts

### 1. Response Models (Pydantic)

Response models define the structure and validation rules for LLM outputs.

#### Basic Model

```python
from pydantic import BaseModel, Field

class Article(BaseModel):
    title: str = Field(description="Article title")
    author: str = Field(description="Author name")
    word_count: int = Field(description="Number of words", gt=0)
    tags: list[str] = Field(description="List of relevant tags")

article = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "Analyze this article: [article text]"
    }],
    response_model=Article
)
```

**Benefits:**
- Type safety with Python type hints
- Automatic validation (word_count > 0)
- Self-documenting with Field descriptions
- IDE autocomplete support

#### Nested Models

```python
class Address(BaseModel):
    street: str
    city: str
    country: str

class Person(BaseModel):
    name: str
    age: int
    address: Address  # Nested model

person = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "John lives at 123 Main St, Boston, USA"
    }],
    response_model=Person
)

print(person.address.city)  # "Boston"
```

#### Optional Fields

```python
from typing import Optional

class Product(BaseModel):
    name: str
    price: float
    discount: Optional[float] = None  # Optional
    description: str = Field(default="No description")  # Default value

# The LLM doesn't need to provide discount or description
```

#### Enums for Constraints

```python
from enum import Enum

class Sentiment(str, Enum):
    POSITIVE = "positive"
    NEGATIVE = "negative"
    NEUTRAL = "neutral"

class Review(BaseModel):
    text: str
    sentiment: Sentiment  # Only these 3 values allowed

review = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "This product is amazing!"
    }],
    response_model=Review
)

print(review.sentiment)  # Sentiment.POSITIVE
```

### 2. Validation

Pydantic validates LLM outputs automatically. If validation fails, Instructor retries.

#### Built-in Validators

```python
from pydantic import Field, EmailStr, HttpUrl

class Contact(BaseModel):
    name: str = Field(min_length=2, max_length=100)
    age: int = Field(ge=0, le=120)  # 0 <= age <= 120
    email: EmailStr   # Validates email format
    website: HttpUrl  # Validates URL format

# If the LLM provides invalid data, Instructor retries automatically
```

#### Custom Validators

```python
from pydantic import field_validator

class Event(BaseModel):
    name: str
    date: str
    attendees: int

    @field_validator('date')
    def validate_date(cls, v):
        """Ensure date is in YYYY-MM-DD format."""
        import re
        if not re.match(r'^\d{4}-\d{2}-\d{2}$', v):
            raise ValueError('Date must be in YYYY-MM-DD format')
        return v

    @field_validator('attendees')
    def validate_attendees(cls, v):
        """Ensure a positive attendee count."""
        if v < 1:
            raise ValueError('Must have at least 1 attendee')
        return v
```

#### Model-Level Validation

```python
from pydantic import model_validator

class DateRange(BaseModel):
    start_date: str
    end_date: str

    @model_validator(mode='after')
    def check_dates(self):
        """Ensure end_date is after start_date."""
        from datetime import datetime
        start = datetime.strptime(self.start_date, '%Y-%m-%d')
        end = datetime.strptime(self.end_date, '%Y-%m-%d')

        if end < start:
            raise ValueError('end_date must be after start_date')
        return self
```

### 3. Automatic Retrying

Instructor retries automatically when validation fails, providing error feedback to the LLM.

```python
# Retries up to 3 times if validation fails
user = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "Extract user from: John, age unknown"
    }],
    response_model=User,
    max_retries=3  # Default is 3
)

# If age can't be extracted, Instructor tells the LLM:
# "Validation error: age - field required"
# The LLM tries again with better extraction
```

**How it works:**
1. LLM generates output
2. Pydantic validates
3. If invalid: error message sent back to LLM
4. LLM tries again with error feedback
5. Repeats up to max_retries
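
The feedback loop above can be sketched in plain Python. `create_with_retries`, `fake_llm`, and `fake_validate` below are hypothetical stand-ins for illustration, not Instructor APIs:

```python
import json
from typing import Callable

def create_with_retries(generate: Callable, validate: Callable,
                        messages: list, max_retries: int = 3) -> dict:
    """Illustrative retry loop: validate each attempt, feed error text back."""
    attempt_messages = list(messages)
    last_error = None
    for _ in range(max_retries):
        raw = generate(attempt_messages)      # 1. LLM generates output
        try:
            return validate(raw)              # 2. Validation (Pydantic's role)
        except ValueError as e:               # 3. If invalid...
            last_error = e
            attempt_messages.append(          # 4. ...send the error back
                {"role": "user", "content": f"Validation error: {e}. Try again."})
    raise last_error                          # 5. Give up after max_retries

# Toy stand-ins: the "LLM" fails once, then succeeds when it sees feedback
def fake_llm(msgs):
    saw_feedback = any("Validation error" in m["content"] for m in msgs)
    return '{"age": 30}' if saw_feedback else '{}'

def fake_validate(raw):
    data = json.loads(raw)
    if "age" not in data:
        raise ValueError("age - field required")
    return data

print(create_with_retries(fake_llm, fake_validate,
                          [{"role": "user", "content": "John, 30"}]))  # {'age': 30}
```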

### 4. Streaming

Stream partial results for real-time processing.

#### Streaming Partial Objects

```python
class Story(BaseModel):
    title: str
    content: str
    tags: list[str]

# Stream partial updates as the LLM generates
for partial_story in client.messages.create_partial(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "Write a short sci-fi story"
    }],
    response_model=Story
):
    print(f"Title: {partial_story.title}")
    print(f"Content so far: {partial_story.content[:100]}...")
    # Update the UI in real time
```

#### Streaming Iterables

```python
class Task(BaseModel):
    title: str
    priority: str

# Stream list items as they're generated
tasks = client.messages.create_iterable(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "Generate 10 project tasks"
    }],
    response_model=Task
)

for task in tasks:
    print(f"- {task.title} ({task.priority})")
    # Process each task as it arrives
```

## Provider Configuration

### Anthropic Claude

```python
import instructor
from anthropic import Anthropic

client = instructor.from_anthropic(
    Anthropic(api_key="your-api-key")
)

# Use with Claude models
response = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[...],
    response_model=YourModel
)
```

### OpenAI

```python
from openai import OpenAI

client = instructor.from_openai(
    OpenAI(api_key="your-api-key")
)

response = client.chat.completions.create(
    model="gpt-4o-mini",
    response_model=YourModel,
    messages=[...]
)
```

### Local Models (Ollama)

```python
from openai import OpenAI

# Point to the local Ollama server
client = instructor.from_openai(
    OpenAI(
        base_url="http://localhost:11434/v1",
        api_key="ollama"  # Required but ignored
    ),
    mode=instructor.Mode.JSON
)

response = client.chat.completions.create(
    model="llama3.1",
    response_model=YourModel,
    messages=[...]
)
```

## Common Patterns

### Pattern 1: Data Extraction from Text

```python
class CompanyInfo(BaseModel):
    name: str
    founded_year: int
    industry: str
    employees: int
    headquarters: str

text = """
Tesla, Inc. was founded in 2003. It operates in the automotive and energy
industry with approximately 140,000 employees. The company is headquartered
in Austin, Texas.
"""

company = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": f"Extract company information from: {text}"
    }],
    response_model=CompanyInfo
)
```

### Pattern 2: Classification

```python
class Category(str, Enum):
    TECHNOLOGY = "technology"
    FINANCE = "finance"
    HEALTHCARE = "healthcare"
    EDUCATION = "education"
    OTHER = "other"

class ArticleClassification(BaseModel):
    category: Category
    confidence: float = Field(ge=0.0, le=1.0)
    keywords: list[str]

classification = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "Classify this article: [article text]"
    }],
    response_model=ArticleClassification
)
```

### Pattern 3: Multi-Entity Extraction

```python
class Person(BaseModel):
    name: str
    role: str

class Organization(BaseModel):
    name: str
    industry: str

class Entities(BaseModel):
    people: list[Person]
    organizations: list[Organization]
    locations: list[str]

text = "Tim Cook, CEO of Apple, announced at the event in Cupertino..."

entities = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": f"Extract all entities from: {text}"
    }],
    response_model=Entities
)

for person in entities.people:
    print(f"{person.name} - {person.role}")
```

### Pattern 4: Structured Analysis

```python
class SentimentAnalysis(BaseModel):
    overall_sentiment: Sentiment
    positive_aspects: list[str]
    negative_aspects: list[str]
    suggestions: list[str]
    score: float = Field(ge=-1.0, le=1.0)

review = "The product works well but setup was confusing..."

analysis = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": f"Analyze this review: {review}"
    }],
    response_model=SentimentAnalysis
)
```

### Pattern 5: Batch Processing

```python
def extract_person(text: str) -> Person:
    return client.messages.create(
        model="claude-sonnet-4-5-20250929",
        max_tokens=1024,
        messages=[{
            "role": "user",
            "content": f"Extract person from: {text}"
        }],
        response_model=Person
    )

texts = [
    "John Doe is a 30-year-old engineer",
    "Jane Smith, 25, works in marketing",
    "Bob Johnson, age 40, software developer"
]

people = [extract_person(text) for text in texts]
```
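
Because each extraction is a network-bound API call, the same batch pattern can be parallelized with a thread pool. A minimal sketch, with a hypothetical `extract_person_stub` standing in for the real client call:

```python
from concurrent.futures import ThreadPoolExecutor

def extract_person_stub(text: str) -> dict:
    """Stand-in for extract_person(); a real version would call the client."""
    name = text.split(",")[0].split(" is ")[0]
    return {"name": name.strip(), "source": text}

texts = [
    "John Doe is a 30-year-old engineer",
    "Jane Smith, 25, works in marketing",
    "Bob Johnson, age 40, software developer",
]

# map() preserves input order even though the calls run concurrently
with ThreadPoolExecutor(max_workers=3) as pool:
    people = list(pool.map(extract_person_stub, texts))

print([p["name"] for p in people])  # ['John Doe', 'Jane Smith', 'Bob Johnson']
```

Keep `max_workers` modest to stay within provider rate limits.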

## Advanced Features

### Union Types

```python
from typing import Union

class TextContent(BaseModel):
    type: str = "text"
    content: str

class ImageContent(BaseModel):
    type: str = "image"
    url: HttpUrl
    caption: str

class Post(BaseModel):
    title: str
    content: Union[TextContent, ImageContent]  # Either type

# The LLM chooses the appropriate type based on the content
```

### Dynamic Models

```python
from pydantic import create_model

# Create a model at runtime
DynamicUser = create_model(
    'User',
    name=(str, ...),
    age=(int, Field(ge=0)),
    email=(EmailStr, ...)
)

user = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[...],
    response_model=DynamicUser
)
```

### Custom Modes

```python
# For providers without native structured outputs
client = instructor.from_anthropic(
    Anthropic(),
    mode=instructor.Mode.JSON  # JSON mode
)

# Available modes:
# - Mode.ANTHROPIC_TOOLS (recommended for Claude)
# - Mode.JSON (fallback)
# - Mode.TOOLS (OpenAI tools)
```

### Context Management

```python
# Single-use client
with instructor.from_anthropic(Anthropic()) as client:
    result = client.messages.create(
        model="claude-sonnet-4-5-20250929",
        max_tokens=1024,
        messages=[...],
        response_model=YourModel
    )
# Client closed automatically
```

## Error Handling

### Handling Validation Errors

```python
from pydantic import ValidationError

try:
    user = client.messages.create(
        model="claude-sonnet-4-5-20250929",
        max_tokens=1024,
        messages=[...],
        response_model=User,
        max_retries=3
    )
except ValidationError as e:
    print(f"Failed after retries: {e}")
    # Handle gracefully
except Exception as e:
    print(f"API error: {e}")
```

### Schema Examples

Embedding example values in the schema helps the model produce valid output on the first attempt, reducing validation errors.

```python
class ValidatedUser(BaseModel):
    name: str = Field(description="Full name, 2-100 characters")
    age: int = Field(description="Age between 0 and 120", ge=0, le=120)
    email: EmailStr = Field(description="Valid email address")

    class Config:
        # Example values embedded in the JSON schema guide the model
        json_schema_extra = {
            "examples": [
                {
                    "name": "John Doe",
                    "age": 30,
                    "email": "john@example.com"
                }
            ]
        }
```

## Best Practices

### 1. Clear Field Descriptions

```python
# ❌ Bad: Vague
class Product(BaseModel):
    name: str
    price: float

# ✅ Good: Descriptive
class Product(BaseModel):
    name: str = Field(description="Product name from the text")
    price: float = Field(description="Price in USD, without currency symbol")
```

### 2. Use Appropriate Validation

```python
# ✅ Good: Constrain values
class Rating(BaseModel):
    score: int = Field(ge=1, le=5, description="Rating from 1 to 5 stars")
    review: str = Field(min_length=10, description="Review text, at least 10 chars")
```

### 3. Provide Examples in Prompts

```python
messages = [{
    "role": "user",
    "content": """Extract person info from: "John, 30, engineer"

Example format:
{
    "name": "John Doe",
    "age": 30,
    "occupation": "engineer"
}"""
}]
```

### 4. Use Enums for Fixed Categories

```python
# ✅ Good: Enum ensures valid values
class Status(str, Enum):
    PENDING = "pending"
    APPROVED = "approved"
    REJECTED = "rejected"

class Application(BaseModel):
    status: Status  # The LLM must choose from the enum
```

### 5. Handle Missing Data Gracefully

```python
class PartialData(BaseModel):
    required_field: str
    optional_field: Optional[str] = None
    default_field: str = "default_value"

# The LLM only needs to provide required_field
```

## Comparison to Alternatives

| Feature | Instructor | Manual JSON | LangChain | DSPy |
|---------|------------|-------------|-----------|------|
| Type Safety | ✅ Yes | ❌ No | ⚠️ Partial | ✅ Yes |
| Auto Validation | ✅ Yes | ❌ No | ❌ No | ⚠️ Limited |
| Auto Retry | ✅ Yes | ❌ No | ❌ No | ✅ Yes |
| Streaming | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
| Multi-Provider | ✅ Yes | ⚠️ Manual | ✅ Yes | ✅ Yes |
| Learning Curve | Low | Low | Medium | High |

**When to choose Instructor:**
- Need structured, validated outputs
- Want type safety and IDE support
- Require automatic retries
- Building data extraction systems

**When to choose alternatives:**
- DSPy: Need prompt optimization
- LangChain: Building complex chains
- Manual: Simple, one-off extractions

## Resources

- **Documentation**: https://python.useinstructor.com
- **GitHub**: https://github.com/jxnl/instructor (15k+ stars)
- **Cookbook**: https://python.useinstructor.com/examples
- **Discord**: Community support available

## See Also

- `references/validation.md` - Advanced validation patterns
- `references/providers.md` - Provider-specific configuration
- `references/examples.md` - Real-world use cases

# Real-World Examples

Practical examples of using Instructor for structured data extraction.

## Data Extraction

```python
class CompanyInfo(BaseModel):
    name: str
    founded: int
    industry: str
    employees: int

text = "Apple was founded in 1976 in the technology industry with 164,000 employees."

company = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": f"Extract: {text}"}],
    response_model=CompanyInfo
)
```

## Classification

```python
class Sentiment(str, Enum):
    POSITIVE = "positive"
    NEGATIVE = "negative"
    NEUTRAL = "neutral"

class Review(BaseModel):
    sentiment: Sentiment
    confidence: float = Field(ge=0.0, le=1.0)

review = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": "This product is amazing!"}],
    response_model=Review
)
```

## Multi-Entity Extraction

```python
class Person(BaseModel):
    name: str
    role: str

class Entities(BaseModel):
    people: list[Person]
    organizations: list[str]
    locations: list[str]

entities = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": "Tim Cook, CEO of Apple, spoke in Cupertino..."}],
    response_model=Entities
)
```

## Structured Analysis

```python
class Analysis(BaseModel):
    summary: str
    key_points: list[str]
    sentiment: Sentiment
    actionable_items: list[str]

analysis = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": "Analyze: [long text]"}],
    response_model=Analysis
)
```

## Batch Processing

```python
texts = ["text1", "text2", "text3"]
results = [
    client.messages.create(
        model="claude-sonnet-4-5-20250929",
        max_tokens=1024,
        messages=[{"role": "user", "content": text}],
        response_model=YourModel
    )
    for text in texts
]
```

## Streaming

```python
for partial in client.messages.create_partial(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": "Generate report..."}],
    response_model=Report
):
    print(f"Progress: {partial.title}")
    # Update the UI in real time
```
# Provider Configuration

Guide to using Instructor with different LLM providers.

## Anthropic Claude

```python
import instructor
from anthropic import Anthropic

# Basic setup
client = instructor.from_anthropic(Anthropic())

# With an explicit API key
client = instructor.from_anthropic(
    Anthropic(api_key="your-api-key")
)

# Recommended mode
client = instructor.from_anthropic(
    Anthropic(),
    mode=instructor.Mode.ANTHROPIC_TOOLS
)

# Usage
result = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{"role": "user", "content": "..."}],
    response_model=YourModel
)
```

## OpenAI

```python
from openai import OpenAI

client = instructor.from_openai(OpenAI())

result = client.chat.completions.create(
    model="gpt-4o-mini",
    response_model=YourModel,
    messages=[{"role": "user", "content": "..."}]
)
```

## Local Models (Ollama)

```python
client = instructor.from_openai(
    OpenAI(
        base_url="http://localhost:11434/v1",
        api_key="ollama"
    ),
    mode=instructor.Mode.JSON
)

result = client.chat.completions.create(
    model="llama3.1",
    response_model=YourModel,
    messages=[...]
)
```

## Modes

- `Mode.ANTHROPIC_TOOLS`: Recommended for Claude
- `Mode.TOOLS`: OpenAI function calling
- `Mode.JSON`: Fallback for unsupported providers
# Advanced Validation Patterns

Complete guide to validation in Instructor using Pydantic.

## Table of Contents

- Built-in Validators
- Custom Field Validators
- Model-Level Validation
- Complex Validation Patterns
- Error Handling

## Built-in Validators

### Numeric Constraints

```python
from pydantic import BaseModel, Field

class Product(BaseModel):
    price: float = Field(gt=0, description="Price must be positive")
    discount: float = Field(ge=0, le=100, description="Discount 0-100%")
    quantity: int = Field(ge=1, description="At least 1 item")
    rating: float = Field(ge=0.0, le=5.0, description="Rating 0-5 stars")

# If the LLM provides invalid values, automatic retry with error feedback
```

**Available constraints:**
- `gt`: Greater than
- `ge`: Greater than or equal
- `lt`: Less than
- `le`: Less than or equal
- `multiple_of`: Must be a multiple of this number
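
The semantics of these constraints can be illustrated without Pydantic. A toy checker (a hypothetical helper, not part of the library) applying the same rules in plain Python:

```python
def check_constraints(value, gt=None, ge=None, lt=None, le=None, multiple_of=None):
    """Apply Pydantic-style numeric constraints; return the violated rule, or None."""
    if gt is not None and not value > gt:
        return f"must be greater than {gt}"
    if ge is not None and not value >= ge:
        return f"must be >= {ge}"
    if lt is not None and not value < lt:
        return f"must be less than {lt}"
    if le is not None and not value <= le:
        return f"must be <= {le}"
    if multiple_of is not None and value % multiple_of != 0:
        return f"must be a multiple of {multiple_of}"
    return None  # all constraints satisfied

print(check_constraints(-1.0, gt=0))           # price: "must be greater than 0"
print(check_constraints(150, ge=0, le=100))    # discount: "must be <= 100"
print(check_constraints(4.5, ge=0.0, le=5.0))  # rating: None (valid)
```

In real use, a violated constraint becomes the validation error that Instructor feeds back to the model on retry.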

### String Constraints

```python
class User(BaseModel):
    username: str = Field(
        min_length=3,
        max_length=20,
        pattern=r'^[a-zA-Z0-9_]+$',
        description="3-20 alphanumeric characters"
    )
    bio: str = Field(max_length=500, description="Bio up to 500 chars")
    status: str = Field(pattern=r'^(active|inactive|pending)$')

# pattern validates against a regex
```

### Email and URL Validation

```python
from pydantic import EmailStr, HttpUrl, AnyUrl

class Contact(BaseModel):
    email: EmailStr     # Validates email format
    website: HttpUrl    # Validates HTTP/HTTPS URLs
    portfolio: AnyUrl   # Any valid URL scheme

contact = client.messages.create(
    model="claude-sonnet-4-5-20250929",
    max_tokens=1024,
    messages=[{
        "role": "user",
        "content": "Extract: john@example.com, https://example.com"
    }],
    response_model=Contact
)
```

### Date and DateTime Validation

```python
from datetime import date, datetime
from pydantic import Field, field_validator

class Event(BaseModel):
    event_date: date      # Validates date format
    created_at: datetime  # Validates datetime format
    year: int = Field(ge=1900, le=2100)

    @field_validator('event_date')
    def future_date(cls, v):
        """Ensure the event is in the future."""
        if v < date.today():
            raise ValueError('Event must be in the future')
        return v
```

### List and Dict Validation

```python
class Document(BaseModel):
    tags: list[str] = Field(min_length=1, max_length=10)
    keywords: list[str] = Field(min_length=3, description="At least 3 keywords")
    metadata: dict[str, str] = Field(description="String key-value pairs")

    @field_validator('tags')
    def unique_tags(cls, v):
        """Ensure tags are unique."""
        if len(v) != len(set(v)):
            raise ValueError('Tags must be unique')
        return v
```

## Custom Field Validators

### Basic Field Validator

```python
from pydantic import field_validator

class Person(BaseModel):
    name: str
    age: int

    @field_validator('name')
    def name_must_not_be_empty(cls, v):
        """Validate name is not empty or just whitespace."""
        if not v or not v.strip():
            raise ValueError('Name cannot be empty')
        return v.strip()

    @field_validator('age')
    def age_must_be_reasonable(cls, v):
        """Validate age is between 0 and 120."""
        if v < 0 or v > 120:
            raise ValueError('Age must be between 0 and 120')
        return v
```
|
||||
|
||||
### Validator with Field Info
|
||||
|
||||
```python
|
||||
from pydantic import ValidationInfo
|
||||
|
||||
class Article(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
@field_validator('content')
|
||||
def content_length(cls, v, info: ValidationInfo):
|
||||
"""Validate content is longer than title."""
|
||||
if 'title' in info.data:
|
||||
title_len = len(info.data['title'])
|
||||
if len(v) < title_len * 2:
|
||||
raise ValueError('Content should be at least 2x title length')
|
||||
return v
|
||||
```
|
||||
|
||||
### Multiple Fields Validation

```python
import re

class TimeRange(BaseModel):
    start_time: str
    end_time: str

    @field_validator('start_time', 'end_time')
    def valid_time_format(cls, v):
        """Validate both times are in HH:MM format."""
        if not re.match(r'^\d{2}:\d{2}$', v):
            raise ValueError('Time must be in HH:MM format')
        return v
```
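The bare `\d{2}:\d{2}` pattern above accepts out-of-range values like `99:99`. A stricter standalone check (a sketch, independent of Pydantic) can bound each field:

```python
import re

# Reject hours > 23 and minutes > 59, which \d{2}:\d{2} would accept
TIME_RE = re.compile(r'^([01]\d|2[0-3]):[0-5]\d$')

def is_valid_time(v: str) -> bool:
    """Return True only for real HH:MM clock times."""
    return bool(TIME_RE.match(v))
```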

### Transform and Validate

```python
class URL(BaseModel):
    url: str

    @field_validator('url')
    def normalize_url(cls, v):
        """Add https:// if missing."""
        if not v.startswith(('http://', 'https://')):
            v = f'https://{v}'
        return v
```

## Model-Level Validation

### Cross-Field Validation

```python
from datetime import datetime

from pydantic import model_validator

class DateRange(BaseModel):
    start_date: str
    end_date: str

    @model_validator(mode='after')
    def check_dates(self):
        """Ensure end_date is not before start_date."""
        start = datetime.strptime(self.start_date, '%Y-%m-%d')
        end = datetime.strptime(self.end_date, '%Y-%m-%d')

        if end < start:
            raise ValueError('end_date must be after start_date')
        return self

class PriceRange(BaseModel):
    min_price: float
    max_price: float

    @model_validator(mode='after')
    def check_price_range(self):
        """Ensure max > min."""
        if self.max_price <= self.min_price:
            raise ValueError('max_price must be greater than min_price')
        return self
```

### Conditional Validation

```python
class Order(BaseModel):
    order_type: str  # "standard" or "express"
    delivery_date: str
    delivery_time: Optional[str] = None

    @model_validator(mode='after')
    def check_delivery_time(self):
        """Express orders need delivery time."""
        if self.order_type == "express" and not self.delivery_time:
            raise ValueError('Express orders require delivery_time')
        return self
```

### Complex Business Logic

```python
class Discount(BaseModel):
    code: str
    percentage: float = Field(ge=0, le=100)
    min_purchase: float = Field(ge=0)
    max_discount: float = Field(ge=0)

    @model_validator(mode='after')
    def validate_discount(self):
        """Ensure discount logic is sound."""
        # Max discount can't exceed percentage of min_purchase
        theoretical_max = (self.percentage / 100) * self.min_purchase
        if self.max_discount > theoretical_max:
            self.max_discount = theoretical_max
        return self
```

## Complex Validation Patterns

### Nested Model Validation

```python
import re

class Address(BaseModel):
    street: str
    city: str
    country: str
    postal_code: str

    @field_validator('postal_code')
    def validate_postal_code(cls, v, info: ValidationInfo):
        """Validate postal code format based on country."""
        if 'country' in info.data:
            country = info.data['country']
            if country == "USA":
                if not re.match(r'^\d{5}(-\d{4})?$', v):
                    raise ValueError('Invalid US postal code')
            elif country == "Canada":
                if not re.match(r'^[A-Z]\d[A-Z] \d[A-Z]\d$', v):
                    raise ValueError('Invalid Canadian postal code')
        return v

class Person(BaseModel):
    name: str
    address: Address

# Nested validation runs automatically
```

### List of Models

```python
class Task(BaseModel):
    title: str = Field(min_length=1)
    priority: int = Field(ge=1, le=5)

class Project(BaseModel):
    name: str
    tasks: list[Task] = Field(min_length=1, description="At least 1 task")

    @field_validator('tasks')
    def at_least_one_high_priority(cls, v):
        """Ensure at least one task has priority >= 4."""
        if not any(task.priority >= 4 for task in v):
            raise ValueError('Project needs at least one high-priority task')
        return v
```

### Union Type Validation

```python
from typing import Union

from pydantic import HttpUrl

class TextBlock(BaseModel):
    type: str = "text"
    content: str = Field(min_length=1)

class ImageBlock(BaseModel):
    type: str = "image"
    url: HttpUrl
    alt_text: str

class Page(BaseModel):
    title: str
    blocks: list[Union[TextBlock, ImageBlock]]

    @field_validator('blocks')
    def validate_block_types(cls, v):
        """Ensure first block is TextBlock."""
        if v and not isinstance(v[0], TextBlock):
            raise ValueError('First block must be text')
        return v
```

### Dependent Fields

```python
class Subscription(BaseModel):
    plan: str  # "free", "pro", "enterprise"
    max_users: int
    features: list[str]

    @model_validator(mode='after')
    def validate_plan_limits(self):
        """Enforce plan-specific limits."""
        limits = {
            "free": {"max_users": 1, "required_features": ["basic"]},
            "pro": {"max_users": 10, "required_features": ["basic", "advanced"]},
            "enterprise": {"max_users": 999, "required_features": ["basic", "advanced", "premium"]},
        }

        if self.plan in limits:
            limit = limits[self.plan]

            if self.max_users > limit["max_users"]:
                raise ValueError(f'{self.plan} plan limited to {limit["max_users"]} users')

            for feature in limit["required_features"]:
                if feature not in self.features:
                    raise ValueError(f'{self.plan} plan requires {feature} feature')

        return self
```

## Error Handling

### Graceful Degradation

```python
class OptionalExtraction(BaseModel):
    # Required fields
    title: str

    # Optional fields with defaults
    author: Optional[str] = None
    date: Optional[str] = None
    tags: list[str] = Field(default_factory=list)

# LLM can succeed even if it can't extract everything
```

### Partial Validation

```python
from pydantic import ValidationError

def extract_with_fallback(text: str):
    """Try full extraction, fall back to partial."""
    try:
        # Try full extraction
        return client.messages.create(
            model="claude-sonnet-4-5-20250929",
            max_tokens=1024,
            messages=[{"role": "user", "content": text}],
            response_model=FullModel,
        )
    except ValidationError:
        # Fall back to partial model
        return client.messages.create(
            model="claude-sonnet-4-5-20250929",
            max_tokens=1024,
            messages=[{"role": "user", "content": text}],
            response_model=PartialModel,
        )
```

### Validation Error Inspection

```python
from pydantic import ValidationError

try:
    result = client.messages.create(
        model="claude-sonnet-4-5-20250929",
        max_tokens=1024,
        messages=[...],
        response_model=MyModel,
        max_retries=3,
    )
except ValidationError as e:
    # Inspect specific errors
    for error in e.errors():
        field = error['loc'][0]
        message = error['msg']
        print(f"Field '{field}' failed: {message}")

        # Custom handling per field
        if field == 'email':
            # Handle email validation failure
            pass
```

### Custom Error Messages

```python
class DetailedModel(BaseModel):
    name: str = Field(
        min_length=2,
        max_length=100,
        description="Name between 2-100 characters"
    )
    age: int = Field(
        ge=0,
        le=120,
        description="Age between 0 and 120 years"
    )

    @field_validator('name')
    def validate_name(cls, v):
        """Provide helpful error message."""
        if not v.strip():
            raise ValueError(
                'Name cannot be empty. '
                'Please provide a valid name from the text.'
            )
        return v

# When validation fails, the LLM sees these helpful messages
```

## Validation Best Practices

### 1. Be Specific

```python
# ❌ Bad: Vague validation
class Item(BaseModel):
    name: str

# ✅ Good: Specific constraints
class Item(BaseModel):
    name: str = Field(
        min_length=1,
        max_length=200,
        description="Item name, 1-200 characters"
    )
```

### 2. Provide Context

```python
# ✅ Good: Explain why validation failed
@field_validator('price')
def validate_price(cls, v):
    if v <= 0:
        raise ValueError(
            'Price must be positive. '
            'Extract numeric price from text without currency symbols.'
        )
    return v
```

### 3. Use Enums for Fixed Sets

```python
# ❌ Bad: String validation
status: str

@field_validator('status')
def validate_status(cls, v):
    if v not in ['active', 'inactive', 'pending']:
        raise ValueError('Invalid status')
    return v

# ✅ Good: Enum
class Status(str, Enum):
    ACTIVE = "active"
    INACTIVE = "inactive"
    PENDING = "pending"

status: Status  # Validation is automatic
```

### 4. Balance Strictness

```python
# Too strict: may fail unnecessarily
class StrictModel(BaseModel):
    date: str = Field(pattern=r'^\d{4}-\d{2}-\d{2}$')
    # Fails if the LLM uses "2024-1-5" instead of "2024-01-05"

# Better: normalize in a validator
class FlexibleModel(BaseModel):
    date: str

    @field_validator('date')
    def normalize_date(cls, v):
        from datetime import datetime
        # Parse flexible formats
        for fmt in ['%Y-%m-%d', '%Y/%m/%d', '%m/%d/%Y']:
            try:
                dt = datetime.strptime(v, fmt)
                return dt.strftime('%Y-%m-%d')  # Normalize
            except ValueError:
                continue
        raise ValueError('Invalid date format')
```
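The normalization loop above can be exercised on its own, outside any model (same formats as the validator):

```python
from datetime import datetime

def normalize_date(v: str) -> str:
    """Accept several common formats and emit canonical YYYY-MM-DD."""
    for fmt in ['%Y-%m-%d', '%Y/%m/%d', '%m/%d/%Y']:
        try:
            return datetime.strptime(v, fmt).strftime('%Y-%m-%d')
        except ValueError:
            continue
    raise ValueError('Invalid date format')
```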

### 5. Test Validation

```python
# Test your validators with edge cases
def test_validation():
    # Should succeed
    valid = MyModel(field="valid_value")

    # Should fail
    try:
        invalid = MyModel(field="invalid")
        assert False, "Should have raised ValidationError"
    except ValidationError:
        pass  # Expected

# Run tests before using in production
```

## Advanced Techniques

### Conditional Required Fields

```python
from typing import Optional

class ConditionalModel(BaseModel):
    type: str
    detail_a: Optional[str] = None
    detail_b: Optional[str] = None

    @model_validator(mode='after')
    def check_required_details(self):
        """Require different fields based on type."""
        if self.type == "type_a" and not self.detail_a:
            raise ValueError('type_a requires detail_a')
        if self.type == "type_b" and not self.detail_b:
            raise ValueError('type_b requires detail_b')
        return self
```

### Validation with External Data

```python
class Product(BaseModel):
    sku: str
    name: str

    @field_validator('sku')
    def validate_sku(cls, v):
        """Check SKU exists in database."""
        # Query database or API
        if not database.sku_exists(v):
            raise ValueError(f'SKU {v} not found in catalog')
        return v
```
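The `database` handle above is left abstract. For testing, a minimal in-memory stand-in (hypothetical names, not part of any library) might look like:

```python
class FakeCatalog:
    """In-memory stand-in for the catalog lookup used by the validator."""

    def __init__(self, skus):
        self._skus = set(skus)

    def sku_exists(self, sku: str) -> bool:
        return sku in self._skus

database = FakeCatalog({"SKU-001", "SKU-002"})
```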

### Progressive Validation

```python
# Start with loose validation
class Stage1(BaseModel):
    data: str  # Any string

# Then strict validation
class Stage2(BaseModel):
    data: str = Field(pattern=r'^[A-Z]{3}-\d{6}$')

# Use Stage1 for initial extraction
# Use Stage2 for final validation
```
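A sketch of the two-stage flow using only the standard library (the `promote` helper is illustrative, not part of Pydantic): Stage 1 accepts any string, and a later pass enforces the strict pattern.

```python
import re

# Same format Stage2 enforces via Field(pattern=...)
STAGE2_PATTERN = re.compile(r'^[A-Z]{3}-\d{6}$')

def promote(record: str) -> str:
    """Re-validate a loosely extracted record against the strict format."""
    if not STAGE2_PATTERN.match(record):
        raise ValueError(f'not a valid code: {record!r}')
    return record
```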

## Resources

- **Pydantic Docs**: https://docs.pydantic.dev/latest/concepts/validators/
- **Instructor Examples**: https://python.useinstructor.com/examples

@@ -1,190 +0,0 @@
---
name: tensorrt-llm
description: Optimizes LLM inference with NVIDIA TensorRT for maximum throughput and lowest latency. Use for production deployment on NVIDIA GPUs (A100/H100), when you need 10-100x faster inference than PyTorch, or for serving models with quantization (FP8/INT4), in-flight batching, and multi-GPU scaling.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [tensorrt-llm, torch]
metadata:
  hermes:
    tags: [Inference Serving, TensorRT-LLM, NVIDIA, Inference Optimization, High Throughput, Low Latency, Production, FP8, INT4, In-Flight Batching, Multi-GPU]
---

# TensorRT-LLM

NVIDIA's open-source library for optimizing LLM inference with state-of-the-art performance on NVIDIA GPUs.

## When to use TensorRT-LLM

**Use TensorRT-LLM when:**
- Deploying on NVIDIA GPUs (A100, H100, GB200)
- Need maximum throughput (24,000+ tokens/sec on Llama 3)
- Require low latency for real-time applications
- Working with quantized models (FP8, INT4, FP4)
- Scaling across multiple GPUs or nodes

**Use vLLM instead when:**
- Need simpler setup and a Python-first API
- Want PagedAttention without TensorRT compilation
- Working with AMD GPUs or non-NVIDIA hardware

**Use llama.cpp instead when:**
- Deploying on CPU or Apple Silicon
- Need edge deployment without NVIDIA GPUs
- Want the simpler GGUF quantization format

## Quick start

### Installation

```bash
# Docker (recommended)
docker pull nvidia/tensorrt_llm:latest

# pip install
pip install tensorrt_llm==1.2.0rc3

# Requires CUDA 13.0.0, TensorRT 10.13.2, Python 3.10-3.12
```

### Basic inference

```python
from tensorrt_llm import LLM, SamplingParams

# Initialize model
llm = LLM(model="meta-llama/Meta-Llama-3-8B")

# Configure sampling
sampling_params = SamplingParams(
    max_tokens=100,
    temperature=0.7,
    top_p=0.9
)

# Generate
prompts = ["Explain quantum computing"]
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(output.outputs[0].text)
```

### Serving with trtllm-serve

```bash
# Start server (automatic model download and compilation)
# --tp_size 4 enables tensor parallelism across 4 GPUs
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --tp_size 4 \
  --max_batch_size 256 \
  --max_num_tokens 4096

# Client request
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "meta-llama/Meta-Llama-3-8B",
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "max_tokens": 100
  }'
```

## Key features

### Performance optimizations
- **In-flight batching**: Dynamic batching during generation
- **Paged KV cache**: Efficient memory management
- **Flash Attention**: Optimized attention kernels
- **Quantization**: FP8, INT4, FP4 for 2-4× faster inference
- **CUDA graphs**: Reduced kernel launch overhead

### Parallelism
- **Tensor parallelism (TP)**: Split model across GPUs
- **Pipeline parallelism (PP)**: Layer-wise distribution
- **Expert parallelism**: For Mixture-of-Experts models
- **Multi-node**: Scale beyond a single machine

### Advanced features
- **Speculative decoding**: Faster generation with draft models
- **LoRA serving**: Efficient multi-adapter deployment
- **Disaggregated serving**: Separate prefill and generation

## Common patterns

### Quantized model (FP8)

```python
from tensorrt_llm import LLM

# Load FP8 quantized model (2× faster, 50% memory)
llm = LLM(
    model="meta-llama/Meta-Llama-3-70B",
    dtype="fp8",
    max_num_tokens=8192
)

# Inference same as before
outputs = llm.generate(["Summarize this article..."])
```

### Multi-GPU deployment

```python
# Tensor parallelism across 8 GPUs
llm = LLM(
    model="meta-llama/Meta-Llama-3-405B",
    tensor_parallel_size=8,
    dtype="fp8"
)
```

### Batch inference

```python
# Process 100 prompts efficiently
prompts = [f"Question {i}: ..." for i in range(100)]

outputs = llm.generate(
    prompts,
    sampling_params=SamplingParams(max_tokens=200)
)

# Automatic in-flight batching for maximum throughput
```

## Performance benchmarks

**Meta Llama 3-8B** (H100 GPU):
- Throughput: 24,000 tokens/sec
- Latency: ~10ms per token
- vs PyTorch: **100× faster**

**Llama 3-70B** (8× A100 80GB):
- FP8 quantization: 2× faster than FP16
- Memory: 50% reduction with FP8

## Supported models

- **LLaMA family**: Llama 2, Llama 3, CodeLlama
- **GPT family**: GPT-2, GPT-J, GPT-NeoX
- **Qwen**: Qwen, Qwen2, QwQ
- **DeepSeek**: DeepSeek-V2, DeepSeek-V3
- **Mixtral**: Mixtral-8x7B, Mixtral-8x22B
- **Vision**: LLaVA, Phi-3-vision
- **100+ models** on HuggingFace

## References

- **[Optimization Guide](references/optimization.md)** - Quantization, batching, KV cache tuning
- **[Multi-GPU Setup](references/multi-gpu.md)** - Tensor/pipeline parallelism, multi-node
- **[Serving Guide](references/serving.md)** - Production deployment, monitoring, autoscaling

## Resources

- **Docs**: https://nvidia.github.io/TensorRT-LLM/
- **GitHub**: https://github.com/NVIDIA/TensorRT-LLM
- **Models**: https://huggingface.co/models?library=tensorrt_llm

@@ -1,298 +0,0 @@
# Multi-GPU Deployment Guide

Comprehensive guide to scaling TensorRT-LLM across multiple GPUs and nodes.

## Parallelism Strategies

### Tensor Parallelism (TP)

**What it does**: Splits model layers across GPUs horizontally.

**Use case**:
- Model fits in total GPU memory but not a single GPU
- Need low latency (single forward pass)
- GPUs on same node (NVLink required for best performance)

**Example** (Llama 3-70B on 4× A100):
```python
from tensorrt_llm import LLM

llm = LLM(
    model="meta-llama/Meta-Llama-3-70B",
    tensor_parallel_size=4,  # Split across 4 GPUs
    dtype="fp16"
)

# Model automatically sharded across GPUs
# Single forward pass, low latency
```

**Performance**:
- Latency: ~same as single GPU
- Throughput: 4× higher (4 GPUs)
- Communication: high (activations synced every layer)

### Pipeline Parallelism (PP)

**What it does**: Splits model layers across GPUs vertically (layer-wise).

**Use case**:
- Very large models (175B+)
- Can tolerate higher latency
- GPUs across multiple nodes

**Example** (Llama 3-405B on 8× H100):
```python
llm = LLM(
    model="meta-llama/Meta-Llama-3-405B",
    tensor_parallel_size=4,   # TP=4 within nodes
    pipeline_parallel_size=2, # PP=2 across nodes
    dtype="fp8"
)

# Total: 8 GPUs (4×2)
# Layers 0-40: Node 1 (4 GPUs with TP)
# Layers 41-80: Node 2 (4 GPUs with TP)
```

**Performance**:
- Latency: higher (sequential through pipeline)
- Throughput: high with micro-batching
- Communication: lower than TP

### Expert Parallelism (EP)

**What it does**: Distributes MoE experts across GPUs.

**Use case**: Mixture-of-Experts models (Mixtral, DeepSeek-V2)

**Example** (Mixtral-8x22B on 8× A100):
```python
llm = LLM(
    model="mistralai/Mixtral-8x22B",
    tensor_parallel_size=4,
    expert_parallel_size=2,  # Distribute 8 experts across 2 groups
    dtype="fp8"
)
```

## Configuration Examples

### Small model (7-13B) - Single GPU

```python
# Llama 3-8B on 1× A100 80GB
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B",
    dtype="fp16"  # or fp8 for H100
)
```

**Resources**:
- GPU: 1× A100 80GB
- Memory: ~16GB model + 30GB KV cache
- Throughput: 3,000-5,000 tokens/sec

### Medium model (70B) - Multi-GPU same node

```python
# Llama 3-70B on 4× A100 80GB (NVLink)
llm = LLM(
    model="meta-llama/Meta-Llama-3-70B",
    tensor_parallel_size=4,
    dtype="fp8"  # 70GB → 35GB per GPU
)
```

**Resources**:
- GPU: 4× A100 80GB with NVLink
- Memory: ~35GB per GPU (FP8)
- Throughput: 10,000-15,000 tokens/sec
- Latency: 15-20ms per token

### Large model (405B) - Multi-node

```python
# Llama 3-405B on 2 nodes × 8 H100 = 16 GPUs
llm = LLM(
    model="meta-llama/Meta-Llama-3-405B",
    tensor_parallel_size=8,   # TP within each node
    pipeline_parallel_size=2, # PP across 2 nodes
    dtype="fp8"
)
```

**Resources**:
- GPU: 2 nodes × 8 H100 80GB
- Memory: ~25GB per GPU (FP8)
- Throughput: 20,000-30,000 tokens/sec
- Network: InfiniBand recommended

## Server Deployment

### Single-node multi-GPU

```bash
# Llama 3-70B on 4 GPUs (automatic TP)
trtllm-serve meta-llama/Meta-Llama-3-70B \
  --tp_size 4 \
  --max_batch_size 256 \
  --dtype fp8

# Listens on http://localhost:8000
```

### Multi-node with Ray

```bash
# Node 1 (head node)
ray start --head --port=6379

# Node 2 (worker)
ray start --address='node1:6379'

# Deploy across cluster (num_workers = 2 nodes)
trtllm-serve meta-llama/Meta-Llama-3-405B \
  --tp_size 8 \
  --pp_size 2 \
  --num_workers 2 \
  --dtype fp8
```

### Kubernetes deployment

```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: tensorrt-llm-llama3-70b
spec:
  replicas: 1
  template:
    spec:
      containers:
      - name: trtllm
        image: nvidia/tensorrt_llm:latest
        command:
        - trtllm-serve
        - meta-llama/Meta-Llama-3-70B
        - --tp_size=4
        - --max_batch_size=256
        resources:
          limits:
            nvidia.com/gpu: 4  # Request 4 GPUs
```

## Parallelism Decision Tree

```
Model size < 20GB?
├─ YES: Single GPU (no parallelism)
└─ NO: Model size < 80GB?
   ├─ YES: TP=2 or TP=4 (same node)
   └─ NO: Model size < 320GB?
      ├─ YES: TP=4 or TP=8 (same node, NVLink required)
      └─ NO: TP=8 + PP=2 (multi-node)
```
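The tree above can be expressed as a small helper (the thresholds are the same rough heuristics, in GB of model weights):

```python
def parallelism_plan(model_size_gb: float) -> str:
    """Pick a parallelism strategy from approximate model size."""
    if model_size_gb < 20:
        return "single GPU (no parallelism)"
    if model_size_gb < 80:
        return "TP=2 or TP=4 (same node)"
    if model_size_gb < 320:
        return "TP=4 or TP=8 (same node, NVLink required)"
    return "TP=8 + PP=2 (multi-node)"
```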

## Communication Optimization

### NVLink vs PCIe

**NVLink** (DGX A100, HGX H100):
- Bandwidth: 600 GB/s (A100), 900 GB/s (H100)
- Ideal for TP (high communication)
- **Recommended for all multi-GPU setups**

**PCIe**:
- Bandwidth: 64 GB/s (PCIe 4.0 x16)
- ~10× slower than NVLink
- Avoid TP; use PP instead
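To see why interconnect bandwidth matters for TP, here is a back-of-envelope estimate of all-reduce traffic per token (the layer count, hidden size, and "two all-reduces per layer" are illustrative assumptions, not measured values):

```python
def tp_bytes_per_token(hidden_size: int, num_layers: int, tp: int,
                       bytes_per_elem: int = 2) -> float:
    """Rough all-reduce traffic per generated token, bytes per GPU."""
    # Ring all-reduce moves ~2*(tp-1)/tp of the buffer per participant
    ring_factor = 2 * (tp - 1) / tp
    # Assume two all-reduces over the hidden state per transformer layer
    per_layer = 2 * hidden_size * bytes_per_elem * ring_factor
    return num_layers * per_layer

# Llama 3-70B-like shape on TP=4, FP16 activations
traffic = tp_bytes_per_token(hidden_size=8192, num_layers=80, tp=4)
```

At roughly 4 MB per token, a few thousand tokens/sec already implies tens of GB/s of sustained all-reduce traffic, which PCIe cannot keep up with but NVLink can.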

### InfiniBand for multi-node

**HDR InfiniBand** (200 Gb/s):
- Required for multi-node TP or PP
- Latency: <1μs
- **Essential for 405B+ models**

## Monitoring Multi-GPU

```bash
# Monitor GPU utilization
nvidia-smi dmon -s u

# Monitor memory
nvidia-smi dmon -s m

# Monitor NVLink status
nvidia-smi nvlink --status

# TensorRT-LLM built-in metrics
curl http://localhost:8000/metrics
```

**Key metrics**:
- GPU utilization: target 80-95%
- Memory usage: should be balanced across GPUs
- NVLink traffic: high for TP, low for PP
- Throughput: tokens/sec across all GPUs

## Common Issues

### Imbalanced GPU memory

**Symptom**: GPU 0 has 90% memory used, GPU 3 has 40%

**Solutions**:
- Verify TP/PP configuration
- Check model sharding (should be equal)
- Restart the server to reset state

### Low NVLink utilization

**Symptom**: NVLink bandwidth <100 GB/s with TP=4

**Solutions**:
- Verify NVLink topology: `nvidia-smi topo -m`
- Check for PCIe fallback
- Ensure GPUs are on the same NVSwitch

### OOM with multi-GPU

**Solutions**:
- Increase TP size (more GPUs)
- Reduce batch size
- Enable FP8 quantization
- Use pipeline parallelism

## Performance Scaling

### TP Scaling (Llama 3-70B, FP8)

| GPUs | TP Size | Throughput | Latency | Efficiency |
|------|---------|------------|---------|------------|
| 1 | 1 | OOM | - | - |
| 2 | 2 | 6,000 tok/s | 18ms | 85% |
| 4 | 4 | 11,000 tok/s | 16ms | 78% |
| 8 | 8 | 18,000 tok/s | 15ms | 64% |

**Note**: Efficiency drops with more GPUs due to communication overhead.

### PP Scaling (Llama 3-405B, FP8)

| Nodes | TP | PP | Total GPUs | Throughput |
|-------|----|----|------------|------------|
| 1 | 8 | 1 | 8 | OOM |
| 2 | 8 | 2 | 16 | 25,000 tok/s |
| 4 | 8 | 4 | 32 | 45,000 tok/s |

## Best Practices

1. **Prefer TP over PP** when possible (lower latency)
2. **Use NVLink** for all TP deployments
3. **Use InfiniBand** for multi-node deployments
4. **Start with the smallest TP** that fits the model in memory
5. **Monitor GPU balance** - all GPUs should have similar utilization
6. **Test with a benchmark** before production
7. **Use FP8** on H100 for a 2× speedup

@@ -1,242 +0,0 @@
# TensorRT-LLM Optimization Guide

Comprehensive guide to optimizing LLM inference with TensorRT-LLM.

## Quantization

### FP8 Quantization (Recommended for H100)

**Benefits**:
- 2× faster inference
- 50% memory reduction
- Minimal accuracy loss (<1% perplexity degradation)

**Usage**:
```python
from tensorrt_llm import LLM

# Automatic FP8 quantization
llm = LLM(
    model="meta-llama/Meta-Llama-3-70B",
    dtype="fp8",
    quantization="fp8"
)
```

**Performance** (Llama 3-70B on 8× H100):
- FP16: 5,000 tokens/sec
- FP8: **10,000 tokens/sec** (2× speedup)
- Memory: 140GB → 70GB

### INT4 Quantization (Maximum compression)

**Benefits**:
- 4× memory reduction
- 3-4× faster inference
- Fits larger models on the same hardware

**Usage**:
```python
# INT4 with AWQ calibration
llm = LLM(
    model="meta-llama/Meta-Llama-3-405B",
    dtype="int4_awq",
    quantization="awq"
)

# INT4 with GPTQ calibration
llm = LLM(
    model="meta-llama/Meta-Llama-3-405B",
    dtype="int4_gptq",
    quantization="gptq"
)
```

**Trade-offs**:
- Accuracy: 1-3% perplexity increase
- Speed: 3-4× faster than FP16
- Use case: when memory is critical

## In-Flight Batching

**What it does**: Dynamically batches requests during generation instead of waiting for all sequences to finish.

**Configuration**:
```bash
# max_batch_size: maximum concurrent sequences
# max_num_tokens: total tokens in a batch
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --max_batch_size 256 \
  --max_num_tokens 4096 \
  --enable_chunked_context \
  --scheduler_policy max_utilization
```

**Performance**:
- Throughput: **4-8× higher** vs static batching
- Latency: lower P50/P99 for mixed workloads
- GPU utilization: 80-95% vs 40-60%

## Paged KV Cache

**What it does**: Manages KV cache memory like an OS manages virtual memory (paging).

**Benefits**:
- 40-60% higher throughput
- No memory fragmentation
- Supports longer sequences

**Configuration**:
```python
# Automatic paged KV cache (default)
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B",
    kv_cache_free_gpu_mem_fraction=0.9,  # Use 90% GPU mem for cache
    enable_prefix_caching=True           # Cache common prefixes
)
```

## Speculative Decoding

**What it does**: Uses a small draft model to predict multiple tokens, verified by the target model in parallel.

**Speedup**: 2-3× faster for long generations

**Usage**:
```python
from tensorrt_llm import LLM

# Target model (Llama 3-70B)
llm = LLM(
    model="meta-llama/Meta-Llama-3-70B",
    speculative_model="meta-llama/Meta-Llama-3-8B",  # Draft model
    num_speculative_tokens=5  # Tokens to predict ahead
)

# Same API, 2-3× faster
outputs = llm.generate(prompts)
```

**Best models for drafting**:
- Target: Llama 3-70B → Draft: Llama 3-8B
- Target: Qwen2-72B → Draft: Qwen2-7B
- Same family, 8-10× smaller

## CUDA Graphs

**What it does**: Reduces kernel launch overhead by recording GPU operations.

**Benefits**:
- 10-20% lower latency
- More stable P99 latency
- Better for small batch sizes

**Configuration** (automatic by default):
```python
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B",
    enable_cuda_graph=True,  # Default: True
    cuda_graph_cache_size=2  # Cache 2 graph variants
)
```

## Chunked Context

**What it does**: Splits long prompts into chunks to reduce memory spikes.

**Use case**: prompts >8K tokens with limited GPU memory

**Configuration**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --max_num_tokens 4096 \
  --enable_chunked_context \
  --max_chunked_prefill_length 2048  # Process 2K tokens at a time
```

## Overlap Scheduling

**What it does**: Overlaps compute and memory operations.

**Benefits**:
- 15-25% higher throughput
- Better GPU utilization
- Default in v1.2.0+

**No configuration needed** - enabled automatically.

## Quantization Comparison Table

| Method | Memory | Speed | Accuracy | Use Case |
|--------|--------|-------|----------|----------|
| FP16 | 1× (baseline) | 1× | Best | High accuracy needed |
| FP8 | 0.5× | 2× | -0.5% ppl | **H100 default** |
| INT4 AWQ | 0.25× | 3-4× | -1.5% ppl | Memory critical |
| INT4 GPTQ | 0.25× | 3-4× | -2% ppl | Maximum speed |
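The memory column translates directly into a weight-size estimate (a back-of-envelope sketch using the table's multipliers and the standard 2 bytes/parameter for FP16; it ignores KV cache, activations, and runtime overhead):

```python
# Rough weight-memory estimate from the table's memory multipliers.
# Ignores KV cache, activations, and runtime overhead.
FP16_BYTES_PER_PARAM = 2.0

MEMORY_FACTOR = {"fp16": 1.0, "fp8": 0.5, "int4_awq": 0.25, "int4_gptq": 0.25}

def weight_gb(num_params: float, method: str) -> float:
    """Approximate weight memory in GB for a given quantization method."""
    return num_params * FP16_BYTES_PER_PARAM * MEMORY_FACTOR[method] / 1e9

# A 70B model needs roughly 140 GB of weights in FP16, 70 GB in FP8, 35 GB in INT4
for method in MEMORY_FACTOR:
    print(f"{method}: {weight_gb(70e9, method):.0f} GB")
```

This is why a 70B model fits on a single 80 GB GPU only with INT4, needs two GPUs with FP8, and two or more with FP16.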
## Tuning Workflow

1. **Start with defaults**:
   ```python
   llm = LLM(model="meta-llama/Meta-Llama-3-70B")
   ```

2. **Enable FP8** (if H100):
   ```python
   llm = LLM(model="...", dtype="fp8")
   ```

3. **Tune batch size**:
   ```bash
   # Increase until OOM, then back off ~20%
   trtllm-serve ... --max_batch_size 256
   ```

4. **Enable chunked context** (if long prompts):
   ```bash
   --enable_chunked_context --max_chunked_prefill_length 2048
   ```

5. **Try speculative decoding** (if latency critical):
   ```python
   llm = LLM(model="...", speculative_model="...")
   ```
## Benchmarking

```bash
# Install benchmark tool
pip install tensorrt_llm[benchmark]

# Run benchmark
python benchmarks/python/benchmark.py \
  --model meta-llama/Meta-Llama-3-8B \
  --batch_size 64 \
  --input_len 128 \
  --output_len 256 \
  --dtype fp8
```

**Metrics to track**:
- Throughput (tokens/sec)
- Latency P50/P90/P99 (ms)
- GPU memory usage (GB)
- GPU utilization (%)
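If your harness only records raw per-request latencies, the percentile metrics above can be derived with the standard library (a generic sketch, not part of the TensorRT-LLM benchmark tool):

```python
import statistics

def latency_percentiles(latencies_ms: list[float]) -> dict[str, float]:
    """Compute P50/P90/P99 from raw per-request latencies in milliseconds."""
    # quantiles() with n=100 returns the 1st..99th percentile cut points
    q = statistics.quantiles(latencies_ms, n=100)
    return {"p50": q[49], "p90": q[89], "p99": q[98]}

# Synthetic latencies: 50..149 ms
samples = [50.0 + i for i in range(100)]
print(latency_percentiles(samples))
```

P99 is usually the number worth watching in production: it surfaces queueing and batching stalls that the mean hides.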
## Common Issues

**OOM errors**:
- Reduce `max_batch_size`
- Reduce `max_num_tokens`
- Enable INT4 quantization
- Increase `tensor_parallel_size`

**Low throughput**:
- Increase `max_batch_size`
- Enable in-flight batching
- Verify CUDA graphs are enabled
- Check GPU utilization

**High latency**:
- Try speculative decoding
- Reduce `max_batch_size` (less queueing)
- Use FP8 instead of FP16
# Production Serving Guide

Comprehensive guide to deploying TensorRT-LLM in production environments.

## Server Modes

### trtllm-serve (Recommended)

**Features**:
- OpenAI-compatible API
- Automatic model download and compilation
- Built-in load balancing
- Prometheus metrics
- Health checks

**Basic usage**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --tp_size 1 \
  --max_batch_size 256 \
  --port 8000
```

**Advanced configuration**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-70B \
  --tp_size 4 \
  --dtype fp8 \
  --max_batch_size 256 \
  --max_num_tokens 4096 \
  --enable_chunked_context \
  --scheduler_policy max_utilization \
  --port 8000 \
  --api_key $API_KEY  # Optional authentication
```
### Python LLM API (for embedding)

```python
from fastapi import FastAPI

from tensorrt_llm import LLM, SamplingParams


class LLMService:
    def __init__(self):
        self.llm = LLM(
            model="meta-llama/Meta-Llama-3-8B",
            dtype="fp8"
        )

    def generate(self, prompt, max_tokens=100):
        params = SamplingParams(
            max_tokens=max_tokens,
            temperature=0.7
        )
        outputs = self.llm.generate([prompt], params)
        return outputs[0].text


# Use in FastAPI, Flask, etc.
app = FastAPI()
service = LLMService()


@app.post("/generate")
def generate(prompt: str):
    return {"response": service.generate(prompt)}
```
## OpenAI-Compatible API

### Chat Completions

```bash
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "meta-llama/Meta-Llama-3-8B",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Explain quantum computing"}
    ],
    "temperature": 0.7,
    "max_tokens": 500,
    "stream": false
  }'
```

**Response**:
```json
{
  "id": "chat-abc123",
  "object": "chat.completion",
  "created": 1234567890,
  "model": "meta-llama/Meta-Llama-3-8B",
  "choices": [{
    "index": 0,
    "message": {
      "role": "assistant",
      "content": "Quantum computing is..."
    },
    "finish_reason": "stop"
  }],
  "usage": {
    "prompt_tokens": 25,
    "completion_tokens": 150,
    "total_tokens": 175
  }
}
```
### Streaming

```bash
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "meta-llama/Meta-Llama-3-8B",
    "messages": [{"role": "user", "content": "Count to 10"}],
    "stream": true
  }'
```

**Response** (SSE stream):
```
data: {"choices":[{"delta":{"content":"1"}}]}

data: {"choices":[{"delta":{"content":", 2"}}]}

data: {"choices":[{"delta":{"content":", 3"}}]}

data: [DONE]
```
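Consuming a stream like this from Python takes only a few lines. This sketch parses `data:` lines of the shape shown above (a generic OpenAI-style SSE accumulator, not a TensorRT-LLM API):

```python
import json

def accumulate_sse(lines):
    """Collect the text deltas from an OpenAI-style SSE stream."""
    text = []
    for line in lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue  # Skip blank keep-alive lines
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break  # End-of-stream sentinel
        chunk = json.loads(payload)
        delta = chunk["choices"][0]["delta"]
        text.append(delta.get("content", ""))
    return "".join(text)

# The example stream from above
stream = [
    'data: {"choices":[{"delta":{"content":"1"}}]}',
    "",
    'data: {"choices":[{"delta":{"content":", 2"}}]}',
    "",
    'data: {"choices":[{"delta":{"content":", 3"}}]}',
    "",
    "data: [DONE]",
]
```

In a real client you would feed this the response's line iterator instead of a fixed list.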
### Completions

```bash
curl -X POST http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "meta-llama/Meta-Llama-3-8B",
    "prompt": "The capital of France is",
    "max_tokens": 10,
    "temperature": 0.0
  }'
```
## Monitoring

### Prometheus Metrics

**Enable metrics**:
```bash
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --enable_metrics \
  --metrics_port 9090
```

**Key metrics**:
```bash
# Scrape metrics
curl http://localhost:9090/metrics

# Important metrics:
# - trtllm_request_success_total   - Total successful requests
# - trtllm_request_latency_seconds - Request latency histogram
# - trtllm_tokens_generated_total  - Total tokens generated
# - trtllm_active_requests         - Current active requests
# - trtllm_queue_size              - Requests waiting in queue
# - trtllm_gpu_memory_usage_bytes  - GPU memory usage
# - trtllm_kv_cache_usage_ratio    - KV cache utilization
```
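For a quick health script without a full Prometheus stack, the exposition text can be parsed directly. This helper handles only simple `name value` lines (no labels or histogram buckets) and is illustrative, not part of any TensorRT-LLM tooling:

```python
def parse_metrics(text):
    """Parse simple `name value` lines of Prometheus exposition text."""
    metrics = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # Skip blanks and HELP/TYPE comment lines
        name, _, value = line.partition(" ")
        try:
            metrics[name] = float(value)
        except ValueError:
            pass  # Skip lines whose value is not a plain number

    return metrics

# Example scrape output using the metric names listed above
sample = """\
# HELP trtllm_active_requests Current active requests
trtllm_active_requests 12
trtllm_queue_size 3
trtllm_kv_cache_usage_ratio 0.81
"""
```

A cron job could call this against `http://localhost:9090/metrics` and alert when `trtllm_queue_size` stays elevated.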
### Health Checks

```bash
# Readiness probe
curl http://localhost:8000/health/ready

# Liveness probe
curl http://localhost:8000/health/live

# Model info
curl http://localhost:8000/v1/models
```

**Kubernetes probes**:
```yaml
livenessProbe:
  httpGet:
    path: /health/live
    port: 8000
  initialDelaySeconds: 60
  periodSeconds: 10

readinessProbe:
  httpGet:
    path: /health/ready
    port: 8000
  initialDelaySeconds: 30
  periodSeconds: 5
```
## Production Deployment

### Docker Deployment

**Dockerfile**:
```dockerfile
FROM nvidia/tensorrt_llm:latest

# Copy any custom configs
COPY config.yaml /app/config.yaml

# Expose API and metrics ports
EXPOSE 8000 9090

# Start server
CMD ["trtllm-serve", "meta-llama/Meta-Llama-3-8B", \
     "--tp_size", "4", \
     "--dtype", "fp8", \
     "--max_batch_size", "256", \
     "--enable_metrics", \
     "--metrics_port", "9090"]
```

**Run the container**:
```bash
docker run --gpus all -p 8000:8000 -p 9090:9090 \
  tensorrt-llm:latest
```
### Kubernetes Deployment

**Complete deployment**:
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: tensorrt-llm
spec:
  replicas: 2  # Multiple replicas for HA
  selector:
    matchLabels:
      app: tensorrt-llm
  template:
    metadata:
      labels:
        app: tensorrt-llm
    spec:
      containers:
      - name: trtllm
        image: nvidia/tensorrt_llm:latest
        command:
        - trtllm-serve
        - meta-llama/Meta-Llama-3-70B
        - --tp_size=4
        - --dtype=fp8
        - --max_batch_size=256
        - --enable_metrics
        ports:
        - containerPort: 8000
          name: http
        - containerPort: 9090
          name: metrics
        resources:
          limits:
            nvidia.com/gpu: 4
        livenessProbe:
          httpGet:
            path: /health/live
            port: 8000
        readinessProbe:
          httpGet:
            path: /health/ready
            port: 8000
---
apiVersion: v1
kind: Service
metadata:
  name: tensorrt-llm
spec:
  selector:
    app: tensorrt-llm
  ports:
  - name: http
    port: 80
    targetPort: 8000
  - name: metrics
    port: 9090
    targetPort: 9090
  type: LoadBalancer
```
### Load Balancing

**NGINX configuration**:
```nginx
upstream tensorrt_llm {
    least_conn;  # Route to the least-busy server
    server trtllm-1:8000 max_fails=3 fail_timeout=30s;
    server trtllm-2:8000 max_fails=3 fail_timeout=30s;
    server trtllm-3:8000 max_fails=3 fail_timeout=30s;
}

server {
    listen 80;
    location / {
        proxy_pass http://tensorrt_llm;
        proxy_read_timeout 300s;  # Long timeout for slow generations
        proxy_connect_timeout 10s;
    }
}
```
## Autoscaling

### Horizontal Pod Autoscaler (HPA)

```yaml
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: tensorrt-llm-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: tensorrt-llm
  minReplicas: 2
  maxReplicas: 10
  metrics:
  - type: Pods
    pods:
      metric:
        name: trtllm_active_requests
      target:
        type: AverageValue
        averageValue: "50"  # Scale out when avg >50 active requests per pod
```
### Custom Metrics

```yaml
# Scale based on queue size instead
- type: Pods
  pods:
    metric:
      name: trtllm_queue_size
    target:
      type: AverageValue
      averageValue: "10"
```
## Cost Optimization

### GPU Selection

**A100 80GB** ($3-4/hour):
- Use for: 70B models with FP8
- Throughput: 10,000-15,000 tok/s (TP=4)
- Cost per 1M tokens: $0.20-0.30

**H100 80GB** ($6-8/hour):
- Use for: 70B models with FP8, 405B models
- Throughput: 20,000-30,000 tok/s (TP=4)
- Cost per 1M tokens: $0.15-0.25 (2× faster, so cheaper per token despite the higher rate)

**L4** ($0.50-1/hour):
- Use for: 7-8B models
- Throughput: 1,000-2,000 tok/s
- Cost per 1M tokens: $0.25-0.50
### Batch Size Tuning

**Impact on cost** (assuming a $3/hour GPU):
- Batch size 1: 1,000 tok/s → 3.6M tok/hour → ≈$0.83 per 1M tokens
- Batch size 64: 5,000 tok/s → 18M tok/hour → ≈$0.17 per 1M tokens
- **5× cost reduction** with batching

**Recommendation**: Target batch size 32-128 for cost efficiency.
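The same arithmetic applies to any GPU/throughput pair (a plain cost formula, not a TensorRT-LLM utility; the hourly rates are the estimates from this section):

```python
def cost_per_million_tokens(hourly_usd: float, tokens_per_sec: float) -> float:
    """Dollars per 1M generated tokens at a given sustained throughput."""
    tokens_per_hour = tokens_per_sec * 3600
    return hourly_usd / (tokens_per_hour / 1e6)

# $3/hour GPU: raising sustained throughput from 1,000 to 5,000 tok/s
# via batching cuts cost per token 5x
print(cost_per_million_tokens(3.0, 1000))  # batch size 1
print(cost_per_million_tokens(3.0, 5000))  # batch size 64
```

Throughput here must be the sustained aggregate across the batch, not per-request speed, which is why larger batches dominate on cost even though each request decodes slightly slower.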
## Security

### API Authentication

```bash
# Generate an API key
export API_KEY=$(openssl rand -hex 32)

# Start the server with authentication
trtllm-serve meta-llama/Meta-Llama-3-8B \
  --api_key $API_KEY

# Client request
curl -X POST http://localhost:8000/v1/chat/completions \
  -H "Authorization: Bearer $API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"model": "...", "messages": [...]}'
```
### Network Policies

```yaml
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
  name: tensorrt-llm-policy
spec:
  podSelector:
    matchLabels:
      app: tensorrt-llm
  policyTypes:
  - Ingress
  ingress:
  - from:
    - podSelector:
        matchLabels:
          app: api-gateway  # Only allow traffic from the gateway
    ports:
    - protocol: TCP
      port: 8000
```
## Troubleshooting

### High latency

**Diagnosis**:
```bash
# Check queue size
curl http://localhost:9090/metrics | grep queue_size

# Check active requests
curl http://localhost:9090/metrics | grep active_requests
```

**Solutions**:
- Scale horizontally (more replicas)
- Increase batch size (if the GPU is underutilized)
- Enable chunked context (if prompts are long)
- Use FP8 quantization

### OOM crashes

**Solutions**:
- Reduce `max_batch_size`
- Reduce `max_num_tokens`
- Enable FP8 or INT4 quantization
- Increase `tensor_parallel_size`

### Timeout errors

**NGINX config**:
```nginx
proxy_read_timeout 600s;  # 10 minutes for very long generations
proxy_send_timeout 600s;
```
## Best Practices

1. **Use FP8 on H100** for 2× speedup and ~50% cost reduction
2. **Monitor metrics**: set up Prometheus + Grafana
3. **Set readiness probes** to prevent routing to unhealthy pods
4. **Use load balancing** to distribute load across replicas
5. **Tune batch size** to balance latency and throughput
6. **Enable streaming** for better UX in chat applications
7. **Set up autoscaling** to handle traffic spikes
8. **Use persistent volumes** to cache compiled models
9. **Implement retries** to handle transient failures
10. **Monitor costs**: track cost per token