mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-26 01:01:40 +00:00
- Restored 21 skills removed in commits 757d012 and 740dd92: accelerate, audiocraft, code-review, faiss, flash-attention, gguf, grpo-rl-training, guidance, llava, nemo-curator, obliteratus, peft, pytorch-fsdp, pytorch-lightning, simpo, slime, stable-diffusion, tensorrt-llm, torchtitan, trl-fine-tuning, whisper
- Rewrote sync_skills() with proper update semantics:
  * New skills (not in manifest): copied to user dir
  * Existing skills (in manifest + on disk): updated via hash comparison
  * User-deleted skills (in manifest, not on disk): respected, not re-added
  * Stale manifest entries (removed from bundled): cleaned from manifest
- Added sync_skills() to CLI startup (cmd_chat) and gateway startup (start_gateway) — previously it only ran during 'hermes update'
- Updated cmd_update output to show new/updated/cleaned counts
- Rewrote tests: 20 tests covering manifest CRUD, dir hashing, fresh install, user deletion respect, update detection, stale cleanup, and name collision handling

75 bundled skills total. 2002 tests pass.
12 KiB
12 KiB
PyTorch Lightning Callbacks
Overview
Callbacks add functionality to training without modifying the LightningModule. They capture non-essential logic like checkpointing, early stopping, and logging.
Built-In Callbacks
1. ModelCheckpoint
Saves the best checkpoints during training, ranked by a monitored metric:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep the 3 checkpoints with the lowest validation loss.
# NOTE: 'val_loss' must be logged by the LightningModule (self.log),
# otherwise ModelCheckpoint has nothing to rank checkpoints by.
checkpoint = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,  # Also keep a copy of the most recent state (last.ckpt)
    verbose=True,
)

trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader)
Configuration options:
# The commonly used ModelCheckpoint options, grouped by what they control.
checkpoint = ModelCheckpoint(
    # What to rank checkpoints by
    monitor='val_acc',                 # metric logged by the LightningModule
    mode='max',                        # higher is better for accuracy ('min' for losses)
    # How many to keep
    save_top_k=5,                      # retain the 5 best checkpoints
    save_last=True,                    # plus a separate copy of the latest state
    # When to save
    every_n_epochs=1,                  # consider saving after every epoch
    save_on_train_epoch_end=False,     # check at the end of validation, not the training epoch
    # How files are named
    filename='best-{epoch}-{val_acc:.3f}',
    auto_insert_metric_name=False,     # use the template as-is (no "epoch=" / "val_acc=" prefixes)
)
Load checkpoint:
# Reload the weights of the best-ranked checkpoint (e.g. for evaluation).
best_model_path = checkpoint.best_model_path
model = LitModel.load_from_checkpoint(best_model_path)

# Resume an interrupted run from the most recent checkpoint (written because save_last=True).
last_ckpt = 'checkpoints/last.ckpt'
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader, ckpt_path=last_ckpt)
2. EarlyStopping
Stops training when a monitored metric stops improving:
from lightning.pytorch.callbacks import EarlyStopping

# Halt training once val_loss has failed to improve by at least 0.001
# for 5 consecutive checks (one check per epoch by default).
early_stop = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,
    min_delta=0.001,                  # smaller changes do not count as improvement
    check_on_train_epoch_end=False,   # evaluate after validation, not after the training epoch
    strict=True,                      # raise an error if 'val_loss' is never logged
    verbose=True,
)

trainer = L.Trainer(callbacks=[early_stop])
trainer.fit(model, train_loader, val_loader)
# Training ends automatically once the patience is exhausted
Advanced usage:
# Besides patience, EarlyStopping can stop on absolute thresholds.
early_stop = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
    min_delta=0.0,
    stopping_threshold=0.1,     # good enough: stop once val_loss is below 0.1
    divergence_threshold=5.0,   # diverged: stop once val_loss is above 5.0
    check_finite=True,          # stop if val_loss becomes NaN or infinite
    verbose=True,
)
3. LearningRateMonitor
Logs the learning rate (and optionally momentum) during training:
from lightning.pytorch.callbacks import LearningRateMonitor

# Record the learning rate and momentum once per epoch ('step' logs every optimizer step).
lr_monitor = LearningRateMonitor(log_momentum=True, logging_interval='epoch')

trainer = L.Trainer(callbacks=[lr_monitor])
# Values are written to whatever logger the Trainer uses (e.g. TensorBoard, W&B)
4. TQDMProgressBar
Customizes the progress bar:
from lightning.pytorch.callbacks import TQDMProgressBar

# Redraw the bar every 10 batches rather than every batch to reduce console overhead.
progress_bar = TQDMProgressBar(process_position=0, refresh_rate=10)

trainer = L.Trainer(callbacks=[progress_bar])
5. GradientAccumulationScheduler
Dynamic gradient accumulation:
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Accumulate FEWER batches as training progresses: large effective batches
# early on, then more frequent optimizer steps later.
# Keys are the epoch at which a setting starts; values are the number of
# batches to accumulate before each optimizer step.
accumulator = GradientAccumulationScheduler(
    scheduling={
        0: 8,   # Epochs 0-4: accumulate 8 batches
        5: 4,   # Epochs 5-9: accumulate 4 batches
        10: 2,  # Epochs 10+: accumulate 2 batches
    }
)

trainer = L.Trainer(callbacks=[accumulator])
6. StochasticWeightAveraging (SWA)
Averages weights for better generalization:
from lightning.pytorch.callbacks import StochasticWeightAveraging

# Average the weights over the final part of training.
swa = StochasticWeightAveraging(
    swa_epoch_start=0.8,         # a float is a fraction of training: start at 80%
    swa_lrs=1e-2,                # learning rate used during the SWA phase
    annealing_strategy='cos',    # how the LR is annealed to swa_lrs: 'cos' or 'linear'
    annealing_epochs=10,         # epochs spent annealing
)

trainer = L.Trainer(callbacks=[swa])
Custom Callbacks
Basic Custom Callback
from lightning.pytorch.callbacks import Callback


class PrintingCallback(Callback):
    """Print a message when training starts, after each training epoch, and when training ends."""

    def on_train_start(self, trainer, pl_module):
        print("Training is starting!")

    def on_train_end(self, trainer, pl_module):
        print("Training is done!")

    # The generic `on_epoch_end` hook no longer exists in Lightning (recent
    # versions raise an error for callbacks that define it); use the
    # per-stage hook instead.
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch} ended")


# Use it
trainer = L.Trainer(callbacks=[PrintingCallback()])
Advanced Custom Callback
class MetricsCallback(Callback):
    """Sample a metric from the training-step output every N batches and log it.

    Args:
        log_every_n_batches: sample when the per-epoch batch index is a
            multiple of this value.
    """

    def __init__(self, log_every_n_batches=100):
        self.log_every_n_batches = log_every_n_batches
        self.metrics = []  # every sampled value; persisted through state_dict()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Skip all batches except every N-th one.
        if batch_idx % self.log_every_n_batches != 0:
            return
        value = self.compute_metric(outputs)
        self.metrics.append(value)
        # Hand the value to Lightning's logging system
        pl_module.log('custom_metric', value)

    def compute_metric(self, outputs):
        """Reduce the training-step output to a Python number (replace with your own logic)."""
        return outputs['loss'].item()

    def state_dict(self):
        """Return the state to store in the checkpoint for this callback."""
        return {'metrics': self.metrics}

    def load_state_dict(self, state_dict):
        """Restore the sampled history from a checkpoint."""
        self.metrics = state_dict['metrics']
Gradient Monitoring Callback
class GradientMonitorCallback(Callback):
    """Log the global L2 norm of all parameter gradients and warn when it is large.

    Args:
        warn_threshold: norm above which a warning is printed.
    """

    def __init__(self, warn_threshold=100.0):
        self.warn_threshold = warn_threshold

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        # This hook runs once per optimizer step, after gradient accumulation and
        # after mixed-precision gradients have been unscaled. on_after_backward
        # fires on every micro-batch and, under AMP, sees loss-scaled gradients,
        # which inflates the reported norm.
        grads = [p.grad.detach() for p in pl_module.parameters() if p.grad is not None]
        if not grads:
            return  # no gradients yet (e.g. all parameters frozen)

        # Combine the per-parameter norms as tensors and convert to a Python float
        # once, instead of one device->host sync per parameter.
        # Assumes all gradients live on the same device.
        total_norm = (sum(g.norm(2) ** 2 for g in grads) ** 0.5).item()

        pl_module.log('grad_norm', total_norm)

        # Warn if exploding
        if total_norm > self.warn_threshold:
            print(f"Warning: Large gradient norm: {total_norm:.2f}")
Model Inspection Callback
class ModelInspectionCallback(Callback):
    """Log mean/std of selected layer activations on the first batch of each epoch.

    Assumes the LightningModule exposes `model.layer1` and `model.layer2` and
    that their forward returns a tensor. Adapt to your model.
    """

    def __init__(self):
        self.activations = {}
        self._handles = []  # forward-hook handles registered for the current inspection

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if batch_idx != 0:  # only inspect the first batch of each epoch
            return
        self.activations = {}

        def get_activation(name):
            def hook(module, inputs, output):
                self.activations[name] = output.detach()
            return hook

        # Keep the handles so the hooks can be removed after this batch. Without
        # removal, a new pair of hooks is stacked on every epoch and they keep
        # firing (and holding tensors) on every later forward pass.
        self._handles = [
            pl_module.model.layer1.register_forward_hook(get_activation('layer1')),
            pl_module.model.layer2.register_forward_hook(get_activation('layer2')),
        ]

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx != 0:
            return
        for handle in self._handles:
            handle.remove()
        self._handles = []

        # Log activation statistics
        for name, activation in self.activations.items():
            pl_module.log(f'{name}_mean', activation.mean().item())
            pl_module.log(f'{name}_std', activation.std().item())
        self.activations = {}  # release the captured tensors
Callback Hooks
All available hooks:
class MyCallback(Callback):
    """Skeleton listing common Callback hooks with their signatures.

    Override only the hooks you need.
    """

    # Setup/Teardown
    def setup(self, trainer, pl_module, stage):
        """Called when fit, validate, test or predict begins; `stage` names which one."""
        pass

    def teardown(self, trainer, pl_module, stage):
        """Called when fit, validate, test or predict ends."""
        pass

    # Training
    def on_train_start(self, trainer, pl_module):
        pass

    def on_train_epoch_start(self, trainer, pl_module):
        pass

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        pass

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        pass

    def on_train_epoch_end(self, trainer, pl_module):
        pass

    def on_train_end(self, trainer, pl_module):
        pass

    # Validation
    # `dataloader_idx` defaults to 0 as in Lightning's own signature: the Trainer
    # only passes it when several dataloaders are used, so making it required
    # raises TypeError in the usual single-dataloader case.
    def on_validation_start(self, trainer, pl_module):
        pass

    def on_validation_epoch_start(self, trainer, pl_module):
        pass

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        pass

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        pass

    def on_validation_epoch_end(self, trainer, pl_module):
        pass

    def on_validation_end(self, trainer, pl_module):
        pass

    # Test (same structure as validation)
    def on_test_start(self, trainer, pl_module):
        pass
    # ... (on_test_epoch_start, on_test_batch_start(..., dataloader_idx=0), etc.)

    # Predict
    def on_predict_start(self, trainer, pl_module):
        pass
    # ... (on_predict_epoch_start, on_predict_batch_start(..., dataloader_idx=0), etc.)

    # Backward
    def on_before_backward(self, trainer, pl_module, loss):
        pass

    def on_after_backward(self, trainer, pl_module):
        pass

    # Optimizer
    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        pass

    # Checkpointing
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """Add data to the checkpoint dict before it is saved."""
        pass

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        """Restore data from the loaded checkpoint dict."""
        pass
Combining Multiple Callbacks
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

# Built-in callbacks
checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3)
early_stop = EarlyStopping(monitor='val_loss', patience=5)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
# Placeholder for any user-defined Callback subclass
custom_callback = MyCustomCallback()

# Pass them all to the Trainer in one list
all_callbacks = [checkpoint, early_stop, lr_monitor, custom_callback]
trainer = L.Trainer(callbacks=all_callbacks)
trainer.fit(model, train_loader, val_loader)
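Execution order: callbacks generally run in the order they are passed to the Trainer. The exception is ModelCheckpoint callbacks: Lightning moves them to the end of the list, so checkpoints are saved after the other callbacks (e.g. EarlyStopping) have updated their state.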
Best Practices
1. Keep Callbacks Independent
Bad (depends on another callback being present):
class BadCallback(Callback):
    """Counter-example: silently depends on another callback being registered."""

    def on_train_end(self, trainer, pl_module):
        # Assumes ModelCheckpoint is present
        # Per the Lightning Trainer docs, trainer.checkpoint_callback is None when no
        # ModelCheckpoint is configured (e.g. enable_checkpointing=False); this line
        # then raises AttributeError.
        best_path = trainer.checkpoint_callback.best_model_path  # Fragile!
Good (self-contained):
class GoodCallback(Callback):
    """Use ModelCheckpoint's best checkpoint if one is registered, without requiring it."""

    def on_train_end(self, trainer, pl_module):
        # Find checkpoint callback if present; fall back to None instead of
        # leaving best_path unbound when there is none.
        best_path = next(
            (cb.best_model_path for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)),
            None,
        )
        # None: no ModelCheckpoint; '': one exists but nothing was saved yet.
        if not best_path:
            return
        # ... use best_path here ...
2. Use State Dict for Persistence
class StatefulCallback(Callback):
    """Count training batches and record each batch's loss; state survives checkpoint/resume."""

    def __init__(self):
        self.counter = 0   # training batches seen
        self.history = []  # per-batch loss values as Python numbers

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.counter += 1
        batch_loss = outputs['loss'].item()
        self.history.append(batch_loss)

    def state_dict(self):
        """Return the state Lightning stores in the checkpoint for this callback."""
        return dict(counter=self.counter, history=self.history)

    def load_state_dict(self, state_dict):
        """Restore counter and history from a checkpoint."""
        self.counter, self.history = state_dict['counter'], state_dict['history']
3. Handle Distributed Training
class DistributedCallback(Callback):
    """Show how to split per-batch work between global rank 0 and all ranks."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Rank-zero-only work: runs on one process, but still on every batch.
        on_main_process = trainer.is_global_zero
        if on_main_process:
            print("This only prints once in distributed training")

        # Per-process work: each rank sees the loss for its own batch.
        loss = outputs['loss']
        # ... do something with loss on each GPU