1 분 소요

메타 라마3가 출시되가 되어서 meta-llama/Llama-3-8B 모델을 가지고 FSDP 방식으로 학습하는 전체 과정을 단계별로 정리해보았다. 이 튜토리얼은 PyTorch 기반이며, 주로 Hugging Face Transformers와 🤗 Accelerate 없이 직접 FSDP를 구성하는 방식이다.

1. 전제 조건

  • PyTorch >= 2.0
  • GPU 2개 이상 (예: A100, H100)
  • NCCL backend
  • 모델: meta-llama/Llama-3-8B 또는 Llama-3-3B
  • OS: Ubuntu, Python 3.10+

2. 전체 구성 요약

  • 환경 설정 (Public Cloud / On-Prem 등) 및 실행
  • 모델 로딩 및 토크나이저 설정
  • FSDP 초기화 및 구현
  • 데이터셋 로딩 및 샘플러 구성
  • 학습 루프 구현 (forward, loss, backward)
  • FSDP에서의 저장 및 로딩

3. 환경 설정 및 실행

  • 2개 GPU 설정
# 예시: torchrun 방식 (NCCL backend)
torchrun \
  --nproc_per_node=2 \
  --nnodes=1 \
  --rdzv_id=101 \
  --rdzv_backend=c10d \
  --rdzv_endpoint=localhost:29500 \
  llama3_fsdp_train.py

3. 라마3 모델 및 토큰나이저 실행

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

4. FSDP 초기화 및 구현

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def fsdp_wrap_model(model, rank):
    auto_wrap_policy = transformer_auto_wrap_policy({LlamaDecoderLayer})
    model = model.to(rank)
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        device_id=rank,
        sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
        mixed_precision=torch.distributed.fsdp.MixedPrecision(param_dtype=torch.bfloat16)
    )
    return fsdp_model

5. 데이터셋 로딩

from datasets import load_dataset

# Hugging Face에 있는 allganize 의 Finanacial dataset
dataset = load_dataset("allganize/financial-mmlu-ko", split="train")
def tokenize_function(example):
    return tokenizer(example["input"], truncation=True, padding="max_length", max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True)

6. Dataloader 및 Sampler

from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

sampler = DistributedSampler(tokenized_dataset)
dataloader = DataLoader(tokenized_dataset, batch_size=2, sampler=sampler)

7. 학습 루프 (FSDP + Optimizer)

from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(3):
    for batch in dataloader:
        input_ids = batch["input_ids"].to(rank)
        labels = batch["input_ids"].to(rank)  # causal LM

        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if dist.get_rank() == 0:
            print(f"Loss: {loss.item()}")

8. FSDP 모델 저장

# 저장
torch.distributed.barrier()
FSDP.save_model(model, "llama3-fsdp-checkpoint/")

9. FSDP 모델 로드

# 로딩
fsdp_model = FSDP(model, device=rank)
FSDP.load_model(fsdp_model, "llama3-fsdp-checkpoint/")

10. 추가 팁

  • activation_checkpointing을 FSDP와 함께 사용하면 VRAM 사용량을 더 줄일 수 있음
  • FSDP는 3D parallelism (Tensor + Pipeline + Data)과도 병행 가능
  • megatron-lm, fairscale, DeepSpeed 등과의 비교는 옵션에 따라 성능 차이 존재

댓글남기기