DDP 개념과 MLDE 지원여부
DDP(Distributed Data Parallel)는 PyTorch에서 제공하는 분산 학습(Distributed Training) 방식 중 하나로, 여러 개의 GPU 혹은 여러 노드에서 데이터를 병렬로 처리하며 모델을 학습하는 방법이다.
다시 말해, DDP는 각 GPU에 모델 복사본을 놓고, 입력 데이터를 나눠 처리한 뒤, 각자의 그래디언트를 동기화해서 전체 모델을 훈련시키는 방식이다. 그렇다면, DDP에 대해 알아보자!
1. DDP의 핵심 개념
구성 요소 | 설명 |
---|---|
모델 복제 (Replica) | 각 GPU마다 동일한 모델을 복사하여 실행 |
데이터 분산 (Sharding) | 전체 데이터를 GPU 수만큼 나누어 각 GPU에 전달 |
동기화 (Sync) | 각 GPU가 계산한 그래디언트를 AllReduce로 동기화 |
프로세스 기반 실행 | GPU마다 하나의 프로세스를 생성하여 독립 실행 (멀티 프로세스) |
2. DDP 학습 플로우
- 모든 GPU에 동일한 모델을 로드
- 학습 데이터 배치를 GPU 수만큼 나누어 각 GPU에 전달
- 각 GPU가 forward → loss → backward 계산 수행
- 각 GPU의 gradient를 AllReduce 연산으로 평균
- optimizer가 동기화된 gradient로 파라미터 업데이트
- 다음 step 반복
3. DDP와 DataParallel 차이
항목 | torch.nn.DataParallel |
torch.nn.parallel.DistributedDataParallel (DDP) |
---|---|---|
실행 방식 | 한 프로세스, 멀티 스레드 | 멀티 프로세스, 멀티 GPU |
성능 | 낮음 (GIL, 통신 오버헤드 존재) | 높음 (병렬화 효율적, 통신 최적화) |
권장 환경 | 간단한 실험용 코드 | 실전 분산 학습 환경 |
4. DDP 예제 코드
- 모든 GPU에 동일한 모델을 로드
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def train(rank, world_size):
setup(rank, world_size)
model = MyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
# loss, optimizer 정의 후 학습
...
cleanup()
5. DDP 특징
- GPU당 하나의 프로세스를 사용하는 멀티 프로세스 모델
- gradient 동기화는
AllReduce
로 처리됨 torchrun
또는torch.distributed.launch
등을 통해 실행- 분산 데이터로 ader 사용 필수 (
DistributedSampler
)
6. MLDE에서 DDP 지원
-
DDP는 특히 대규모 모델 훈련, 멀티 GPU/멀티 노드 환경에서 성능 효율이 매우 뛰어나기 때문에, PyTorch에서는 사실상 표준 분산 학습 방법으로 간주됨.
-
MLDE에서 DDP를 지원하는지에 대한 공식적인 정보는 확인되지 않음. 그러나 MLDE가 PyTorch 기반의 환경을 제공한다면, PyTorch의 DDP 기능을 활용하여 분산 학습을 구현할 수 있음.
-
PyTorch의 DDP는 멀티 프로세스 기반으로 각 GPU에 모델을 복제하고, 각 프로세스에서 독립적으로 학습을 진행한 후, gradient를 동기화하여 모델을 업데이트하는 방식. GIL(Global Interpreter Lock) 문제를 피하고, 통신 오버헤드를 줄여 효율적인 분산 학습을 가능하게 함.
-
DDP를 사용하기 위해서는
torch.distributed
패키지를 활용하여 프로세스 그룹을 초기화하고, 각 프로세스에 모델을 할당한 후,DistributedDataParallel
클래스로 모델을 감싸는 등의 설정이 필요함
댓글남기기