mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-26 01:01:40 +00:00
Comprehensive cleanup across 80 files, based on automated analysis
(ruff, pyflakes, vulture) and a manual review of the entire codebase.
Changes by category:
Unused imports removed (~95 across 55 files):
- Removed genuinely unused imports from all major subsystems
- agent/, hermes_cli/, tools/, gateway/, plugins/, cron/
- Includes imports in try/except blocks that were truly unused
(vs. availability checks, which were left alone)
Unused variables removed (~25):
- Removed dead variables: connected, inner, channels, last_exc,
source, new_server_names, verify, pconfig, default_terminal,
result, pending_handled, temperature, loop
- Dropped unused argparse subparser assignments in hermes_cli/main.py
(12 instances where the add_parser() result was never used)
Dead code removed:
- run_agent.py: Removed dead ternary (None if False else None) and
surrounding unreachable branch in identity fallback
- run_agent.py: Removed write-only attribute _last_reported_tool
- hermes_cli/providers.py: Removed dead @property decorator on
module-level function (decorator has no effect outside a class)
- gateway/run.py: Removed unused MCP config load before reconnect
- gateway/platforms/slack.py: Removed dead SessionSource construction
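For context, the removed ternary always evaluates to None since both arms
are None, so the branch it fed could never take effect (variable name
illustrative):

    identity = None if False else None  # always None, whatever the condition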
Undefined name bugs fixed (would cause NameError at runtime):
- batch_runner.py: Added missing logger = logging.getLogger(__name__)
- tools/environments/daytona.py: Added missing Dict and Path imports
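The batch_runner.py fix, in outline (the module referenced logger without
defining it):

    import logging

    logger = logging.getLogger(__name__)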
Unnecessary global statements removed (14). In Python, global is only
needed to rebind a module-level name, not to mutate the object it refers
to (see the sketch after this list):
- tools/terminal_tool.py: 5 functions declared global for dicts
they only mutated via .pop()/[key]=value (no rebinding)
- tools/browser_tool.py: cleanup thread loop only reads flag
- tools/rl_training_tool.py: 4 functions only do dict mutations
- tools/mcp_oauth.py: only reads the global
- hermes_time.py: only reads cached values
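A minimal sketch of the rule these removals rely on (names illustrative):

    CACHE = {}

    def record(key, value):
        CACHE[key] = value   # mutating through the existing binding: no global needed

    def reset_cache():
        global CACHE         # rebinding the module-level name: global required
        CACHE = {}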
Inefficient patterns fixed (before/after examples after this list):
- startswith/endswith tuple form: 15 instances of
x.startswith('a') or x.startswith('b') consolidated to
x.startswith(('a', 'b'))
- len(x)==0 / len(x)>0: 13 instances replaced with pythonic
truthiness checks (not x / bool(x))
- in dict.keys(): 5 instances simplified to in dict
- Redefined unused name: removed duplicate _strip_mdv2 import in
send_message_tool.py
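The shape of these rewrites (names hypothetical):

    # Before                                    # After
    s.startswith('a') or s.startswith('b')      s.startswith(('a', 'b'))
    if len(items) > 0: ...                      if items: ...
    if key in mapping.keys(): ...               if key in mapping: ...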
Other fixes:
- hermes_cli/doctor.py: Replaced undefined logger.debug() with pass
- hermes_cli/config.py: Consolidated chained .endswith() calls
Test results: 3934 passed, 17 failed (all pre-existing on main),
19 skipped. Zero regressions.
228 lines
6 KiB
Python
"""
|
|
Basic GRPO Training Template
|
|
=============================
|
|
|
|
A minimal, production-ready template for GRPO training with TRL.
|
|
Adapt this for your specific task by modifying:
|
|
1. Dataset loading (get_dataset function)
|
|
2. Reward functions (reward_*_func)
|
|
3. System prompt (SYSTEM_PROMPT)
|
|
4. Hyperparameters (GRPOConfig)
|
|
"""
|
|
|
|
import torch
|
|
import re
|
|
from datasets import load_dataset
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
from peft import LoraConfig
|
|
from trl import GRPOTrainer, GRPOConfig
|
|
|
|
# ==================== CONFIGURATION ====================
|
|
|
|
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
|
OUTPUT_DIR = "outputs/grpo-model"
|
|
MAX_PROMPT_LENGTH = 256
|
|
MAX_COMPLETION_LENGTH = 512
|
|
|
|
SYSTEM_PROMPT = """
|
|
Respond in the following format:
|
|
<reasoning>
|
|
[Your step-by-step thinking]
|
|
</reasoning>
|
|
<answer>
|
|
[Final answer]
|
|
</answer>
|
|
"""
|
|
|
|

# ==================== DATASET ====================

def get_dataset(split="train"):
    """
    Load and prepare your dataset.

    Returns: Dataset with columns:
    - 'prompt': List[Dict] with role/content
    - 'answer': str (ground truth, optional)
    """
    # Example: GSM8K math dataset
    data = load_dataset('openai/gsm8k', 'main')[split]

    def process_example(x):
        # Extract ground truth answer
        answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['question']}
            ],
            'answer': answer
        }

    return data.map(process_example)

# ==================== HELPER FUNCTIONS ====================

def extract_xml_tag(text: str, tag: str) -> str:
    """Extract content between XML tags."""
    pattern = f'<{tag}>(.*?)</{tag}>'
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""

def extract_answer(text: str) -> str:
    """Extract the final answer from structured output."""
    return extract_xml_tag(text, 'answer')
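
# Example: extract_answer("<reasoning>2 + 2 = 4</reasoning><answer>4</answer>") == "4"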

# ==================== REWARD FUNCTIONS ====================

def correctness_reward_func(prompts, completions, answer, **kwargs):
    """
    Reward correct answers.
    Weight: 2.0 (highest priority)
    """
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_answer(r) for r in responses]
    return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]

def format_reward_func(completions, **kwargs):
    """
    Reward proper XML format.
    Weight: 0.5
    """
    pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
    responses = [comp[0]['content'] for comp in completions]
    return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]

def incremental_format_reward_func(completions, **kwargs):
    """
    Incremental reward for partial format compliance.
    Weight: up to 0.5
    """
    responses = [comp[0]['content'] for comp in completions]
    rewards = []

    for r in responses:
        score = 0.0
        if '<reasoning>' in r:
            score += 0.125
        if '</reasoning>' in r:
            score += 0.125
        if '<answer>' in r:
            score += 0.125
        if '</answer>' in r:
            score += 0.125

        # Penalize extra content after closing tag
        if '</answer>' in r:
            extra = r.split('</answer>')[-1].strip()
            score -= len(extra) * 0.001

        rewards.append(score)

    return rewards
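
# Note: GRPOTrainer sums the outputs of all reward functions for each
# completion, so a correct and fully formatted response earns up to
# 2.0 + 0.5 + 0.5 = 3.0 total reward.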

# ==================== MODEL SETUP ====================

def setup_model_and_tokenizer():
    """Load model and tokenizer with optimizations."""
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",  # requires flash-attn; remove if not installed
        device_map="auto"
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def get_peft_config():
    """LoRA configuration for parameter-efficient training."""
    return LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )

# ==================== TRAINING ====================

def main():
    """Main training function."""

    # Load data
    print("Loading dataset...")
    dataset = get_dataset()
    print(f"Dataset size: {len(dataset)}")

    # Setup model
    print("Loading model...")
    model, tokenizer = setup_model_and_tokenizer()

    # Training configuration
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        run_name="grpo-training",

        # Learning rate
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',

        # Batch settings
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,

        # GRPO specific
        num_generations=8,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,

        # Training duration
        num_train_epochs=1,

        # Optimization
        bf16=True,
        optim="adamw_8bit",
        max_grad_norm=0.1,

        # Logging
        logging_steps=1,
        save_steps=100,
        report_to="wandb",  # Change to "none" to disable logging
    )
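
    # Note: recent TRL versions require the effective generation batch
    # (per_device_train_batch_size * gradient_accumulation_steps * num_processes)
    # to be divisible by num_generations; if GRPOTrainer raises a divisibility
    # error on your setup, adjust the batch settings or num_generations above.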

    # Initialize trainer
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            incremental_format_reward_func,
            format_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        peft_config=get_peft_config(),
    )

    # Train
    print("Starting training...")
    trainer.train()

    # Save final model
    print(f"Saving model to {OUTPUT_DIR}/final")
    trainer.save_model(f"{OUTPUT_DIR}/final")

    print("Training complete!")

if __name__ == "__main__":
    main()