FSDP란 무엇인가
FSDP(Fully Sharded Data Parallel)는 PyTorch에서 제공하는 고성능 분산 학습 기법으로, 모델의 파라미터, gradient, optimizer state를 GPU 간에 완전히(sharded) 나눠서 메모리 효율적으로 학습하는 방식이다.
따라서, FSDP는 모델 파라미터와 학습 상태를 GPU 간에 분할 저장하고, 필요한 시점에만 통신과 계산을 수행하는 PyTorch의 메모리 효율적 분산 학습 기법이다. 그러면, DDP 와 FSDP를 비교해보자!
1. DDP vs FSDP 차이점
항목 | DDP (DistributedDataParallel) | FSDP (FullyShardedDataParallel) |
---|---|---|
파라미터 저장 방식 | 각 GPU에 전체 모델 복사 | 파라미터를 GPU마다 분산 저장 (Sharding) |
메모리 사용량 | 높음 | 매우 효율적 (최대 8배 절감) |
통신 시점 | backward 중 AllReduce | forward/backward 시점에서 AllGather/ReduceScatter |
초기화 방식 | 전체 모델 로딩 후 wrapping | 모듈 단위 wrapping 가능, lazy loading 가능 |
권장 사용 | 중소 모델 | 초거대 모델 (LLM 등) |
2. 학습 동작 플로우
- 모델 분할: 각 GPU는 전체 모델의 일부 파라미터만 저장함 (sharding)
- Forward 시점에 AllGather: 연산에 필요한 파라미터를 AllGather로 잠시 모음
- Backward 시점에 ReduceScatter: gradient 계산 후, 필요한 부분만 다시 나눠서 저장
- Optimizer State도 Sharding: 메모리 효율 극대화
3. FSDP가 중요한 이유
- LLaMA, GPT, T5 같은 초거대 모델 훈련 시 필수
- 단일 GPU에 올라가지 않는 모델도 메모리 분산으로 훈련 가능
- HuggingFace, Meta AI, MosaicML 등의 LLM 학습 코드에서 적극 사용 중
4. FSDP 예제 코드
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.distributed as dist
def setup(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def main(rank, world_size):
setup(rank, world_size)
model = MyLargeModel().to(rank)
model = FSDP(model, device_id=rank)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()
for batch in dataloader:
outputs = model(batch['input'])
loss = loss_fn(outputs, batch['target'])
loss.backward()
optimizer.step()
optimizer.zero_grad()
5. FSDP에 적합한 경우
-
모델이 단일 GPU VRAM에 안 들어가는 경우
-
모델 파라미터 수가 수십억 이상인 경우
-
Checkpointing, activation recomputation 등과 함께 쓰는 경우
댓글남기기