
# FSDP2 in TorchTitan

## Why FSDP2?

FSDP2 is a rewrite of PyTorch's Fully Sharded Data Parallel (FSDP) API, removing the FlatParameter abstraction for better composability and simpler implementation.

## Key Improvements over FSDP1

- **DTensor-based sharding:** sharded parameters are DTensors sharded on dim-0, enabling easy manipulation and communication-free sharded state dicts
- **Better memory management:** deterministic and lower GPU memory (7% reduction) by avoiding `recordStream`
- **Simplified API:** fewer arguments, no wrapper class
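For intuition on the dim-0 layout: a `Shard(0)` placement splits a parameter's first dimension across ranks with `torch.chunk`-style sizing, so early ranks get ceil-sized pieces and late ranks may get fewer (or zero) rows. A stdlib sketch of who owns how many rows (illustrative only, no torch dependency):

```python
def dim0_shard_sizes(dim0: int, world_size: int) -> list[int]:
    """Rows each rank owns when a parameter is chunked on dim-0.

    Mirrors torch.chunk sizing: ceil-sized pieces first, so late ranks
    may receive fewer (or zero) rows when dim0 % world_size != 0.
    """
    chunk = -(-dim0 // world_size)  # ceil division
    sizes, remaining = [], dim0
    for _ in range(world_size):
        take = min(chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

# A (10, 4096) weight across 4 ranks: rows split as [3, 3, 3, 1]
print(dim0_shard_sizes(10, 4))
# A tiny (4, ...) weight across 8 ranks: four ranks own zero rows
print(dim0_shard_sizes(4, 8))
```

Because each shard is just a contiguous dim-0 slice, saving a sharded state dict needs no communication: every rank writes its own rows.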

## Performance

On Llama-7B with 8x H100 GPUs, FSDP2 achieves higher MFU and 7% lower peak memory than FSDP1 while matching its loss curve.

## API Reference

```python
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, OffloadPolicy

@contract(state_cls=FSDPState)
def fully_shard(
    module: nn.Module,
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Union[bool, int] = True,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
) -> nn.Module: ...
```

## Sharding Strategies (ZeRO Equivalents)

| FSDP2 Configuration | FSDP1 Equivalent | DeepSpeed |
|---|---|---|
| 1D mesh + `reshard_after_forward=True` | `FULL_SHARD` | ZeRO-3 |
| 1D mesh + `reshard_after_forward=False` | `SHARD_GRAD_OP` | ZeRO-2 |
| 2D mesh + `reshard_after_forward=True` | `HYBRID_SHARD` | MiCS |
| 1D/2D mesh + `reshard_after_forward=8` (int) | - | ZeRO++ hpZ |
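The ZeRO-2 vs ZeRO-3 rows trade memory for communication: `reshard_after_forward=False` keeps the gathered full parameters resident between forward and backward, while `True` frees them back to the 1/world_size shard. A back-of-envelope per-rank footprint estimate (illustrative helper only; it ignores gradients, optimizer state, and the one transiently gathered layer group):

```python
GIB = 1024 ** 3

def param_gib_per_rank(n_params: int, world_size: int,
                       reshard_after_forward: bool,
                       bytes_per_param: int = 2) -> float:
    """Rough steady-state parameter footprint per rank, in GiB.

    reshard_after_forward=True (ZeRO-3-like): only the 1/world_size
    shard persists between forward and backward.
    reshard_after_forward=False (ZeRO-2-like): the gathered full
    parameters stay resident until backward.
    """
    total = n_params * bytes_per_param
    return (total / world_size if reshard_after_forward else total) / GIB

# 7B parameters in bf16 across 8 GPUs
zero3 = param_gib_per_rank(7_000_000_000, 8, True)   # small resident shard
zero2 = param_gib_per_rank(7_000_000_000, 8, False)  # full bf16 copy resident
```

With these assumptions the shard is about 1.6 GiB versus roughly 13 GiB for the full copy, which is why `False` only pays off when the extra all-gathers in backward are the bottleneck.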

## Meta-Device Initialization

FSDP2 supports materializing tensors onto GPU after sharding:

```python
import itertools

import torch
from torch.distributed._composable.fsdp import fully_shard

# Transformer / TransformerBlock are the TorchTitan model classes

# Initialize on meta device (no memory allocated)
with torch.device("meta"):
    model = Transformer()

# Apply FSDP2 sharding bottom-up: blocks first, then the root module
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)

# Parameters are still on the meta device
for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")

# Allocate the (sharded) parameters on GPU
model.to_empty(device="cuda")

# Initialize weights
model.init_weights()
```

## State Dict Differences

| Operation | FSDP1 | FSDP2 |
|---|---|---|
| `model.state_dict()` | Full state dict | Sharded state dict (no communication) |
| `optim.state_dict()` | Local state dict | Sharded state dict (no communication) |
| `summon_full_params()` | Supported | Use DTensor APIs like `full_tensor()` |
| Gradient clipping | `FSDP.clip_grad_norm_()` | `nn.utils.clip_grad_norm_()` |
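The gradient-clipping row works because the standard utility only needs one global L2 norm, which DTensor gradients can compute across shards. The math it implements, as a stdlib sketch (hypothetical `clip_grad_norm` helper, flat Python lists standing in for tensors):

```python
import math

def clip_grad_norm(grads: list[list[float]], max_norm: float) -> float:
    """Scale gradients in place so their global L2 norm is <= max_norm.

    Mirrors the math of nn.utils.clip_grad_norm_: one total norm over
    all tensors, one shared scale factor applied to every gradient.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = max_norm / (total_norm + 1e-6)
    if scale < 1.0:
        for grad in grads:
            for i, g in enumerate(grad):
                grad[i] = g * scale
    return total_norm

grads = [[3.0, 4.0], [0.0]]            # global norm = 5
norm = clip_grad_norm(grads, max_norm=1.0)
# grads are scaled to [[~0.6, ~0.8], [0.0]]
```

Under FSDP2 each rank holds only its shard of every gradient, so the per-tensor sum of squares is a sharded partial and one all-reduce yields the same `total_norm` everywhere.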

## Mixed Precision

```python
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,
    cast_forward_inputs=True,
)

fully_shard(model, mp_policy=mp_policy)
```
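One consequence of `param_dtype=torch.bfloat16` with `reduce_dtype=torch.float32` is asymmetric communication volume: parameters are all-gathered in bf16, but gradients are reduce-scattered in fp32. A rough per-rank traffic estimate (hypothetical helper, assuming bandwidth-optimal collectives where each rank moves `(world_size - 1) / world_size` of the data):

```python
GIB = 1024 ** 3

def collective_gib(n_params: int, dtype_bytes: int, world_size: int) -> float:
    """Approximate per-rank traffic for one all-gather or reduce-scatter
    over n_params elements: each rank sends/receives everything except
    its own 1/world_size shard."""
    return n_params * dtype_bytes * (world_size - 1) / world_size / GIB

# 7B parameters across 8 ranks
fwd_gather = collective_gib(7_000_000_000, 2, 8)  # params gathered in bf16
bwd_reduce = collective_gib(7_000_000_000, 4, 8)  # grads reduced in fp32
```

Under these assumptions the fp32 gradient reduction moves exactly twice the bytes of the bf16 parameter gather, which is the price of numerically stable gradient accumulation.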

## HSDP (Hybrid Sharded Data Parallel)

For 2D parallelism with replication + sharding:

```python
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

# Replicate across 4 groups, shard within 8 GPUs each
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))

fully_shard(model, mesh=mesh)
```
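With a `(4, 8)` mesh, global ranks are laid out row-major over the mesh dimensions, so a rank's replicate group and shard position follow from integer division. A sketch of that mapping (hypothetical helper; `init_device_mesh`'s row-major layout is the assumption):

```python
def hsdp_coords(rank: int, replicate: int, shard: int) -> tuple[int, int]:
    """Map a global rank to (replicate_group, shard_position) for a
    row-major (replicate, shard) device mesh."""
    assert 0 <= rank < replicate * shard, "rank outside the mesh"
    return divmod(rank, shard)

# 32 GPUs as (4, 8): rank 13 sits in replicate group 1, shard slot 5
print(hsdp_coords(13, 4, 8))
```

Each row of the mesh is one FSDP shard group (all-gather/reduce-scatter), and each column is one replica group (gradient all-reduce), which is how HSDP keeps sharding traffic inside a node while replicating across nodes.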

## Configuration in TorchTitan

```toml
[parallelism]
# FSDP sharding degree (-1 = auto, use all available GPUs)
data_parallel_shard_degree = -1

# HSDP replication degree (1 = pure FSDP, >1 = HSDP)
data_parallel_replicate_degree = 1
```
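The two degrees must multiply out to the world size. A sketch of how `-1` could resolve (illustrative helper only, not TorchTitan's actual validation code):

```python
def resolve_shard_degree(shard_degree: int, replicate_degree: int,
                         world_size: int) -> int:
    """Resolve data_parallel_shard_degree = -1 to a concrete value:
    shard over whatever GPUs remain after replication."""
    if shard_degree == -1:
        shard_degree = world_size // replicate_degree
    assert shard_degree * replicate_degree == world_size, \
        "shard degree x replicate degree must equal the world size"
    return shard_degree

# 32 GPUs with replicate_degree=4: -1 resolves to sharding over 8
print(resolve_shard_degree(-1, 4, 32))
```

With the defaults above (`-1` and `1`), this collapses to pure FSDP over every available GPU.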

## Removed Arguments from FSDP1

These FSDP1 arguments are no longer needed:

- `auto_wrap_policy`: apply `fully_shard` directly to modules instead
- `backward_prefetch`: always uses `BACKWARD_PRE`
- `param_init_fn`: use meta-device initialization
- `device_id`: uses the mesh's device automatically
- `sync_module_states`: not needed with DTensor
- `limit_all_gathers`: the new memory management doesn't need it
- `use_orig_params`: always true (no FlatParameter)