LLaMA 3 모델로 FSDP 학습하기
메타 라마3가 출시되가 되어서 meta-llama/Llama-3-8B 모델을 가지고 FSDP 방식으로 학습하는 전체 과정을 단계별로 정리해보았다. 이 튜토리얼은 PyTorch 기반이며, 주로 Hugging Face Transformers와 🤗 Accelerate 없이 직접 FSDP를 구성하는 방식이다.
1. 전제 조건
- PyTorch >= 2.0
- GPU 2개 이상 (예: A100, H100)
- NCCL backend
- 모델:
meta-llama/Llama-3-8B
또는Llama-3-3B
- OS: Ubuntu, Python 3.10+
2. 전체 구성 요약
- 환경 설정 (Public Cloud / On-Prem 등) 및 실행
- 모델 로딩 및 토크나이저 설정
- FSDP 초기화 및 구현
- 데이터셋 로딩 및 샘플러 구성
- 학습 루프 구현 (forward, loss, backward)
- FSDP에서의 저장 및 로딩
3. 환경 설정 및 실행
- 2개 GPU 설정
# 예시: torchrun 방식 (NCCL backend)
torchrun \
--nproc_per_node=2 \
--nnodes=1 \
--rdzv_id=101 \
--rdzv_backend=c10d \
--rdzv_endpoint=localhost:29500 \
llama3_fsdp_train.py
3. 라마3 모델 및 토큰나이저 실행
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "meta-llama/Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
4. FSDP 초기화 및 구현
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
def fsdp_wrap_model(model, rank):
auto_wrap_policy = transformer_auto_wrap_policy({LlamaDecoderLayer})
model = model.to(rank)
fsdp_model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
device_id=rank,
sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
mixed_precision=torch.distributed.fsdp.MixedPrecision(param_dtype=torch.bfloat16)
)
return fsdp_model
5. 데이터셋 로딩
from datasets import load_dataset
# Hugging Face에 있는 allganize 의 Finanacial dataset
dataset = load_dataset("allganize/financial-mmlu-ko", split="train")
def tokenize_function(example):
return tokenizer(example["input"], truncation=True, padding="max_length", max_length=512)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
6. Dataloader 및 Sampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
sampler = DistributedSampler(tokenized_dataset)
dataloader = DataLoader(tokenized_dataset, batch_size=2, sampler=sampler)
7. 학습 루프 (FSDP + Optimizer)
from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(3):
for batch in dataloader:
input_ids = batch["input_ids"].to(rank)
labels = batch["input_ids"].to(rank) # causal LM
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
if dist.get_rank() == 0:
print(f"Loss: {loss.item()}")
8. FSDP 모델 저장
# 저장
torch.distributed.barrier()
FSDP.save_model(model, "llama3-fsdp-checkpoint/")
9. FSDP 모델 로드
# 로딩
fsdp_model = FSDP(model, device=rank)
FSDP.load_model(fsdp_model, "llama3-fsdp-checkpoint/")
10. 추가 팁
activation_checkpointing
을 FSDP와 함께 사용하면 VRAM 사용량을 더 줄일 수 있음- FSDP는 3D parallelism (Tensor + Pipeline + Data)과도 병행 가능
megatron-lm
,fairscale
,DeepSpeed
등과의 비교는 옵션에 따라 성능 차이 존재
댓글남기기