mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
- Introduced new skills tools: `skills_categories`, `skills_list`, and `skill_view` in `model_tools.py`, allowing for better organization and access to skill-related functionalities. - Updated `toolsets.py` to include a new `skills` toolset, providing a dedicated space for skill tools. - Enhanced `batch_runner.py` to recognize and validate skills tools during batch processing. - Added comprehensive tool definitions for skills tools, ensuring compatibility with OpenAI's expected format. - Created new shell script `test_skills_kimi.sh` for testing skills tool functionality with Kimi K2.5. - Added example skill files demonstrating the structure and usage of skills within the Hermes-Agent framework, including `SKILL.md` for example and audiocraft skills. - Improved documentation for skills tools and their integration into the existing tool framework, ensuring clarity for future development and usage.
12 KiB
12 KiB
PyTorch Lightning Callbacks
Overview
Callbacks add functionality to training without modifying the LightningModule. They capture non-essential logic like checkpointing, early stopping, and logging.
Built-In Callbacks
1. ModelCheckpoint
Saves best models during training:
from lightning.pytorch.callbacks import ModelCheckpoint
# Save top 3 models based on validation loss
checkpoint = ModelCheckpoint(
dirpath='checkpoints/',
filename='model-{epoch:02d}-{val_loss:.2f}',
monitor='val_loss',
mode='min',
save_top_k=3,
save_last=True, # Also save last epoch
verbose=True
)
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader)
Configuration options:
checkpoint = ModelCheckpoint(
monitor='val_acc', # Metric to monitor
mode='max', # 'max' for accuracy, 'min' for loss
save_top_k=5, # Keep best 5 models
save_last=True, # Save last epoch separately
every_n_epochs=1, # Save every N epochs
save_on_train_epoch_end=False, # Save on validation end instead
filename='best-{epoch}-{val_acc:.3f}', # Naming pattern
auto_insert_metric_name=False # Don't auto-add metric to filename
)
Load checkpoint:
# Load best model
best_model_path = checkpoint.best_model_path
model = LitModel.load_from_checkpoint(best_model_path)
# Resume training
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader, ckpt_path='checkpoints/last.ckpt')
2. EarlyStopping
Stops training when metric stops improving:
from lightning.pytorch.callbacks import EarlyStopping
early_stop = EarlyStopping(
monitor='val_loss',
patience=5, # Wait 5 epochs
mode='min',
min_delta=0.001, # Minimum change to qualify as improvement
verbose=True,
strict=True, # Crash if monitored metric not found
check_on_train_epoch_end=False # Check on validation end
)
trainer = L.Trainer(callbacks=[early_stop])
trainer.fit(model, train_loader, val_loader)
# Stops automatically if no improvement for 5 epochs
Advanced usage:
early_stop = EarlyStopping(
monitor='val_loss',
patience=10,
min_delta=0.0,
verbose=True,
mode='min',
stopping_threshold=0.1, # Stop if val_loss < 0.1
divergence_threshold=5.0, # Stop if val_loss > 5.0
check_finite=True # Stop on NaN/Inf
)
3. LearningRateMonitor
Logs learning rate:
from lightning.pytorch.callbacks import LearningRateMonitor
lr_monitor = LearningRateMonitor(
logging_interval='epoch', # Or 'step'
log_momentum=True # Also log momentum
)
trainer = L.Trainer(callbacks=[lr_monitor])
# Learning rate automatically logged to TensorBoard/WandB
4. TQDMProgressBar
Customizes progress bar:
from lightning.pytorch.callbacks import TQDMProgressBar
progress_bar = TQDMProgressBar(
refresh_rate=10, # Update every 10 batches
process_position=0
)
trainer = L.Trainer(callbacks=[progress_bar])
5. GradientAccumulationScheduler
Dynamic gradient accumulation:
from lightning.pytorch.callbacks import GradientAccumulationScheduler
# Accumulate more gradients as training progresses
accumulator = GradientAccumulationScheduler(
scheduling={
0: 8, # Epochs 0-4: accumulate 8 batches
5: 4, # Epochs 5-9: accumulate 4 batches
10: 2 # Epochs 10+: accumulate 2 batches
}
)
trainer = L.Trainer(callbacks=[accumulator])
6. StochasticWeightAveraging (SWA)
Averages weights for better generalization:
from lightning.pytorch.callbacks import StochasticWeightAveraging
swa = StochasticWeightAveraging(
swa_lrs=1e-2, # SWA learning rate
swa_epoch_start=0.8, # Start at 80% of training
annealing_epochs=10, # Annealing period
annealing_strategy='cos' # 'cos' or 'linear'
)
trainer = L.Trainer(callbacks=[swa])
Custom Callbacks
Basic Custom Callback
from lightning.pytorch.callbacks import Callback
class PrintingCallback(Callback):
def on_train_start(self, trainer, pl_module):
print("Training is starting!")
def on_train_end(self, trainer, pl_module):
print("Training is done!")
def on_train_epoch_end(self, trainer, pl_module): # note: 'on_epoch_end' was removed in Lightning 2.0
print(f"Epoch {trainer.current_epoch} ended")
# Use it
trainer = L.Trainer(callbacks=[PrintingCallback()])
Advanced Custom Callback
class MetricsCallback(Callback):
"""Logs custom metrics every N batches."""
def __init__(self, log_every_n_batches=100):
self.log_every_n_batches = log_every_n_batches
self.metrics = []
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if batch_idx % self.log_every_n_batches == 0:
# Compute custom metric
metric = self.compute_metric(outputs)
self.metrics.append(metric)
# Log to Lightning
pl_module.log('custom_metric', metric)
def compute_metric(self, outputs):
# Your custom logic
return outputs['loss'].item()
def state_dict(self):
"""Save callback state in checkpoint."""
return {'metrics': self.metrics}
def load_state_dict(self, state_dict):
"""Restore callback state from checkpoint."""
self.metrics = state_dict['metrics']
Gradient Monitoring Callback
class GradientMonitorCallback(Callback):
"""Monitor gradient norms."""
def on_after_backward(self, trainer, pl_module):
# Compute gradient norm
total_norm = 0.0
for p in pl_module.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
# Log
pl_module.log('grad_norm', total_norm)
# Warn if exploding
if total_norm > 100:
print(f"Warning: Large gradient norm: {total_norm:.2f}")
Model Inspection Callback
class ModelInspectionCallback(Callback):
"""Inspect model activations during training."""
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
if batch_idx == 0: # First batch of epoch
# Register hooks
self.activations = {}
def get_activation(name):
def hook(model, input, output):
self.activations[name] = output.detach()
return hook
# Attach to specific layers
pl_module.model.layer1.register_forward_hook(get_activation('layer1'))
pl_module.model.layer2.register_forward_hook(get_activation('layer2'))
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if batch_idx == 0:
# Log activation statistics
for name, activation in self.activations.items():
mean = activation.mean().item()
std = activation.std().item()
pl_module.log(f'{name}_mean', mean)
pl_module.log(f'{name}_std', std)
Callback Hooks
All available hooks:
class MyCallback(Callback):
# Setup/Teardown
def setup(self, trainer, pl_module, stage):
"""Called at beginning of fit/test/predict."""
pass
def teardown(self, trainer, pl_module, stage):
"""Called at end of fit/test/predict."""
pass
# Training
def on_train_start(self, trainer, pl_module):
pass
def on_train_epoch_start(self, trainer, pl_module):
pass
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
pass
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
pass
def on_train_epoch_end(self, trainer, pl_module):
pass
def on_train_end(self, trainer, pl_module):
pass
# Validation
def on_validation_start(self, trainer, pl_module):
pass
def on_validation_epoch_start(self, trainer, pl_module):
pass
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
pass
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
pass
def on_validation_epoch_end(self, trainer, pl_module):
pass
def on_validation_end(self, trainer, pl_module):
pass
# Test (same structure as validation)
def on_test_start(self, trainer, pl_module):
pass
# ... (on_test_epoch_start, on_test_batch_start, etc.)
# Predict
def on_predict_start(self, trainer, pl_module):
pass
# ... (on_predict_epoch_start, on_predict_batch_start, etc.)
# Backward
def on_before_backward(self, trainer, pl_module, loss):
pass
def on_after_backward(self, trainer, pl_module):
pass
# Optimizer
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
pass
# Checkpointing
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
"""Add data to checkpoint."""
pass
def on_load_checkpoint(self, trainer, pl_module, checkpoint):
"""Restore data from checkpoint."""
pass
Combining Multiple Callbacks
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
# Create all callbacks
checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3)
early_stop = EarlyStopping(monitor='val_loss', patience=5)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
custom_callback = MyCustomCallback()
# Add all to Trainer
trainer = L.Trainer(
callbacks=[checkpoint, early_stop, lr_monitor, custom_callback]
)
trainer.fit(model, train_loader, val_loader)
Execution order: Callbacks run in the order they are added, except that Lightning moves checkpoint callbacks (e.g. ModelCheckpoint) to the end of the list so they always run last.
Best Practices
1. Keep Callbacks Independent
Bad (dependent on other callback):
class BadCallback(Callback):
def on_train_end(self, trainer, pl_module):
# Assumes ModelCheckpoint is present
best_path = trainer.checkpoint_callback.best_model_path # Fragile!
Good (self-contained):
class GoodCallback(Callback):
def on_train_end(self, trainer, pl_module):
# Find checkpoint callback if present
for callback in trainer.callbacks:
if isinstance(callback, ModelCheckpoint):
best_path = callback.best_model_path
break
2. Use State Dict for Persistence
class StatefulCallback(Callback):
def __init__(self):
self.counter = 0
self.history = []
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.counter += 1
self.history.append(outputs['loss'].item())
def state_dict(self):
"""Save state."""
return {
'counter': self.counter,
'history': self.history
}
def load_state_dict(self, state_dict):
"""Restore state."""
self.counter = state_dict['counter']
self.history = state_dict['history']
3. Handle Distributed Training
class DistributedCallback(Callback):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
# Only run on main process
if trainer.is_global_zero:
print("This only prints once in distributed training")
# Run on all processes
loss = outputs['loss']
# ... do something with loss on each GPU