# PyTorch Lightning Callbacks

## Overview

Callbacks add functionality to training without modifying the LightningModule. They capture **non-essential logic** like checkpointing, early stopping, and logging.
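
A minimal sketch of the pattern (the `LossPrinter` name and the `train_loss` key are illustrative; the key must match a metric your LightningModule logs): subclass `Callback`, override the hooks you care about, and pass instances to the `Trainer`:

```python
import lightning as L
from lightning.pytorch.callbacks import Callback

class LossPrinter(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # trainer.callback_metrics holds the most recently logged metrics
        print(trainer.callback_metrics.get("train_loss"))

trainer = L.Trainer(callbacks=[LossPrinter()])
```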

## Built-In Callbacks

### 1. ModelCheckpoint

**Saves best models during training**:

```python
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

# Save top 3 models based on validation loss
checkpoint = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,  # Also save last epoch
    verbose=True
)

trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader)
```
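
For illustration (exact values depend on your run): with the defaults, Lightning inserts each metric name into the pattern, so checkpoints are written as e.g. `checkpoints/model-epoch=04-val_loss=0.31.ckpt`; set `auto_insert_metric_name=False` (shown below) to get `model-04-0.31.ckpt`. `save_last=True` additionally writes `checkpoints/last.ckpt`.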

**Configuration options**:

```python
checkpoint = ModelCheckpoint(
    monitor='val_acc',                      # Metric to monitor
    mode='max',                             # 'max' for accuracy, 'min' for loss
    save_top_k=5,                           # Keep best 5 models
    save_last=True,                         # Save last epoch separately
    every_n_epochs=1,                       # Save every N epochs
    save_on_train_epoch_end=False,          # Save on validation end instead
    filename='best-{epoch}-{val_acc:.3f}',  # Naming pattern
    auto_insert_metric_name=False           # Don't auto-add metric to filename
)
```

**Load checkpoint**:

```python
# Load best model
best_model_path = checkpoint.best_model_path
model = LitModel.load_from_checkpoint(best_model_path)

# Resume training
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader, ckpt_path='checkpoints/last.ckpt')
```

### 2. EarlyStopping

**Stops training when a metric stops improving**:

```python
from lightning.pytorch.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,                      # Wait 5 epochs
    mode='min',
    min_delta=0.001,                 # Minimum change to qualify as improvement
    verbose=True,
    strict=True,                     # Crash if monitored metric not found
    check_on_train_epoch_end=False   # Check on validation end
)

trainer = L.Trainer(callbacks=[early_stop])
trainer.fit(model, train_loader, val_loader)
# Stops automatically if no improvement for 5 epochs
```

**Advanced usage**:

```python
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=0.0,
    verbose=True,
    mode='min',
    stopping_threshold=0.1,     # Stop if val_loss < 0.1
    divergence_threshold=5.0,   # Stop if val_loss > 5.0
    check_finite=True           # Stop on NaN/Inf
)
```

### 3. LearningRateMonitor

**Logs learning rate**:

```python
from lightning.pytorch.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(
    logging_interval='epoch',  # Or 'step'
    log_momentum=True          # Also log momentum
)

trainer = L.Trainer(callbacks=[lr_monitor])
# Learning rate automatically logged to TensorBoard/WandB
```

### 4. TQDMProgressBar

**Customizes progress bar**:

```python
from lightning.pytorch.callbacks import TQDMProgressBar

progress_bar = TQDMProgressBar(
    refresh_rate=10,  # Update every 10 batches
    process_position=0
)

trainer = L.Trainer(callbacks=[progress_bar])
```

### 5. GradientAccumulationScheduler

**Dynamic gradient accumulation**:

```python
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Reduce gradient accumulation as training progresses
accumulator = GradientAccumulationScheduler(
    scheduling={
        0: 8,   # Epochs 0-4: accumulate 8 batches
        5: 4,   # Epochs 5-9: accumulate 4 batches
        10: 2   # Epochs 10+: accumulate 2 batches
    }
)

trainer = L.Trainer(callbacks=[accumulator])
```
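
The effective batch size is the per-device batch size times the accumulation factor (times the number of devices). With a per-device batch size of 16, an assumed value for illustration, the schedule above gives 16 × 8 = 128 for epochs 0-4, 16 × 4 = 64 for epochs 5-9, and 16 × 2 = 32 from epoch 10 onward.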

### 6. StochasticWeightAveraging (SWA)

**Averages weights for better generalization**:

```python
from lightning.pytorch.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(
    swa_lrs=1e-2,             # SWA learning rate
    swa_epoch_start=0.8,      # Start at 80% of training
    annealing_epochs=10,      # Annealing period
    annealing_strategy='cos'  # 'cos' or 'linear'
)

trainer = L.Trainer(callbacks=[swa])
```
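
Conceptually, SWA keeps an equal-weight running average of the weights visited after `swa_epoch_start`: after collecting weights θ₁, …, θₙ at the end of each participating epoch, the averaged model uses θ_SWA = (θ₁ + … + θₙ) / n. Because batch-norm running statistics are no longer valid for the averaged weights, they need to be recomputed with a pass over the training data before the averaged model is used.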

## Custom Callbacks

### Basic Custom Callback

```python
from lightning.pytorch.callbacks import Callback

class PrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting!")

    def on_train_end(self, trainer, pl_module):
        print("Training is done!")

    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch} ended")

# Use it
trainer = L.Trainer(callbacks=[PrintingCallback()])
```

### Advanced Custom Callback

```python
class MetricsCallback(Callback):
    """Logs custom metrics every N batches."""

    def __init__(self, log_every_n_batches=100):
        self.log_every_n_batches = log_every_n_batches
        self.metrics = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx % self.log_every_n_batches == 0:
            # Compute custom metric
            metric = self.compute_metric(outputs)
            self.metrics.append(metric)

            # Log to Lightning
            pl_module.log('custom_metric', metric)

    def compute_metric(self, outputs):
        # Your custom logic
        return outputs['loss'].item()

    def state_dict(self):
        """Save callback state in checkpoint."""
        return {'metrics': self.metrics}

    def load_state_dict(self, state_dict):
        """Restore callback state from checkpoint."""
        self.metrics = state_dict['metrics']
```

### Gradient Monitoring Callback

```python
class GradientMonitorCallback(Callback):
    """Monitor gradient norms."""

    def on_after_backward(self, trainer, pl_module):
        # Compute global gradient norm
        total_norm = 0.0
        for p in pl_module.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5

        # Log
        pl_module.log('grad_norm', total_norm)

        # Warn if exploding
        if total_norm > 100:
            print(f"Warning: Large gradient norm: {total_norm:.2f}")
```

### Model Inspection Callback

```python
class ModelInspectionCallback(Callback):
    """Inspect model activations during training."""

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if batch_idx == 0:  # First batch of each epoch
            # Register hooks and keep the handles so they can be removed later
            self.activations = {}
            self.handles = []

            def get_activation(name):
                def hook(module, input, output):
                    self.activations[name] = output.detach()
                return hook

            # Attach to specific layers (layer names depend on your model)
            self.handles.append(
                pl_module.model.layer1.register_forward_hook(get_activation('layer1'))
            )
            self.handles.append(
                pl_module.model.layer2.register_forward_hook(get_activation('layer2'))
            )

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx == 0:
            # Log activation statistics
            for name, activation in self.activations.items():
                pl_module.log(f'{name}_mean', activation.mean().item())
                pl_module.log(f'{name}_std', activation.std().item())

            # Remove hooks so they are not registered again on the next epoch
            for handle in self.handles:
                handle.remove()
            self.handles = []
```

## Callback Hooks

**All available hooks**:

```python
class MyCallback(Callback):
    # Setup/Teardown
    def setup(self, trainer, pl_module, stage):
        """Called at beginning of fit/test/predict."""
        pass

    def teardown(self, trainer, pl_module, stage):
        """Called at end of fit/test/predict."""
        pass

    # Training
    def on_train_start(self, trainer, pl_module):
        pass

    def on_train_epoch_start(self, trainer, pl_module):
        pass

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        pass

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        pass

    def on_train_epoch_end(self, trainer, pl_module):
        pass

    def on_train_end(self, trainer, pl_module):
        pass

    # Validation
    def on_validation_start(self, trainer, pl_module):
        pass

    def on_validation_epoch_start(self, trainer, pl_module):
        pass

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        pass

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        pass

    def on_validation_epoch_end(self, trainer, pl_module):
        pass

    def on_validation_end(self, trainer, pl_module):
        pass

    # Test (same structure as validation)
    def on_test_start(self, trainer, pl_module):
        pass
    # ... (on_test_epoch_start, on_test_batch_start, etc.)

    # Predict
    def on_predict_start(self, trainer, pl_module):
        pass
    # ... (on_predict_epoch_start, on_predict_batch_start, etc.)

    # Backward
    def on_before_backward(self, trainer, pl_module, loss):
        pass

    def on_after_backward(self, trainer, pl_module):
        pass

    # Optimizer
    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        pass

    # Checkpointing
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """Add data to checkpoint."""
        pass

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        """Restore data from checkpoint."""
        pass
```

## Combining Multiple Callbacks

```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Create all callbacks
checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3)
early_stop = EarlyStopping(monitor='val_loss', patience=5)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
custom_callback = MyCustomCallback()

# Add all to Trainer
trainer = L.Trainer(
    callbacks=[checkpoint, early_stop, lr_monitor, custom_callback]
)

trainer.fit(model, train_loader, val_loader)
```

**Execution order**: Callbacks run in the order they are passed to the `Trainer`, with one exception: Lightning moves checkpoint callbacks to the end of the list so that they execute last.
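
A minimal sketch illustrating the order (the callback names are made up for this example): when two callbacks implement the same hook, the one registered first runs first.

```python
import lightning as L
from lightning.pytorch.callbacks import Callback

class First(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print("First runs before Second")

class Second(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print("Second runs after First")

# For each hook, First's method is invoked before Second's
trainer = L.Trainer(callbacks=[First(), Second()])
```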
## Best Practices

### 1. Keep Callbacks Independent

**Bad** (depends on another callback):

```python
class BadCallback(Callback):
    def on_train_end(self, trainer, pl_module):
        # Assumes ModelCheckpoint is present
        best_path = trainer.checkpoint_callback.best_model_path  # Fragile!
```

**Good** (self-contained):

```python
class GoodCallback(Callback):
    def on_train_end(self, trainer, pl_module):
        # Find checkpoint callback if present
        for callback in trainer.callbacks:
            if isinstance(callback, ModelCheckpoint):
                best_path = callback.best_model_path
                break
```

### 2. Use State Dict for Persistence

```python
class StatefulCallback(Callback):
    def __init__(self):
        self.counter = 0
        self.history = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.counter += 1
        self.history.append(outputs['loss'].item())

    def state_dict(self):
        """Save state."""
        return {
            'counter': self.counter,
            'history': self.history
        }

    def load_state_dict(self, state_dict):
        """Restore state."""
        self.counter = state_dict['counter']
        self.history = state_dict['history']
```

### 3. Handle Distributed Training

```python
class DistributedCallback(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Only run on the main process
        if trainer.is_global_zero:
            print("This only prints once in distributed training")

        # Runs on all processes
        loss = outputs['loss']
        # ... do something with loss on each GPU
```
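
When a callback logs values in a distributed run, each process logs its own value by default. Passing `sync_dist=True` to `log()` reduces (averages) the value across processes. A minimal sketch, with an illustrative callback and metric name, following the same `outputs['loss']` pattern as above:

```python
from lightning.pytorch.callbacks import Callback

class SyncedLoggingCallback(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Reduce the value across all processes instead of logging per-rank values
        pl_module.log('train_loss_synced', outputs['loss'], sync_dist=True)
```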

## Resources

- Callback API: https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html
- Built-in callbacks: https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/callbacks