1 분 소요

DDP(Distributed Data Parallel)는 PyTorch에서 제공하는 분산 학습(Distributed Training) 방식 중 하나로, 여러 개의 GPU 혹은 여러 노드에서 데이터를 병렬로 처리하며 모델을 학습하는 방법이다.

다시 말해, DDP는 각 GPU에 모델 복사본을 놓고, 입력 데이터를 나눠 처리한 뒤, 각자의 그래디언트를 동기화해서 전체 모델을 훈련시키는 방식이다. 그렇다면, DDP에 대해 알아보자!

1. DDP의 핵심 개념

구성 요소 설명
모델 복제 (Replica) 각 GPU마다 동일한 모델을 복사하여 실행
데이터 분산 (Sharding) 전체 데이터를 GPU 수만큼 나누어 각 GPU에 전달
동기화 (Sync) 각 GPU가 계산한 그래디언트를 AllReduce로 동기화
프로세스 기반 실행 GPU마다 하나의 프로세스를 생성하여 독립 실행 (멀티 프로세스)

2. DDP 학습 플로우

  • 모든 GPU에 동일한 모델을 로드
  • 학습 데이터 배치를 GPU 수만큼 나누어 각 GPU에 전달
  • 각 GPU가 forward → loss → backward 계산 수행
  • 각 GPU의 gradient를 AllReduce 연산으로 평균
  • optimizer가 동기화된 gradient로 파라미터 업데이트
  • 다음 step 반복

3. DDP와 DataParallel 차이

항목 torch.nn.DataParallel torch.nn.parallel.DistributedDataParallel (DDP)
실행 방식 한 프로세스, 멀티 스레드 멀티 프로세스, 멀티 GPU
성능 낮음 (GIL, 통신 오버헤드 존재) 높음 (병렬화 효율적, 통신 최적화)
권장 환경 간단한 실험용 코드 실전 분산 학습 환경

4. DDP 예제 코드

  • 모든 GPU에 동일한 모델을 로드
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)
    
    model = MyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # loss, optimizer 정의 후 학습
    ...

    cleanup()

5. DDP 특징

  • GPU당 하나의 프로세스를 사용하는 멀티 프로세스 모델
  • gradient 동기화는 AllReduce로 처리됨
  • torchrun 또는 torch.distributed.launch 등을 통해 실행
  • 분산 데이터로 ader 사용 필수 (DistributedSampler)

6. MLDE에서 DDP 지원

  • DDP는 특히 대규모 모델 훈련, 멀티 GPU/멀티 노드 환경에서 성능 효율이 매우 뛰어나기 때문에, PyTorch에서는 사실상 표준 분산 학습 방법으로 간주됨.

  • MLDE에서 DDP를 지원하는지에 대한 공식적인 정보는 확인되지 않음. 그러나 MLDE가 PyTorch 기반의 환경을 제공한다면, PyTorch의 DDP 기능을 활용하여 분산 학습을 구현할 수 있음.

  • PyTorch의 DDP는 멀티 프로세스 기반으로 각 GPU에 모델을 복제하고, 각 프로세스에서 독립적으로 학습을 진행한 후, gradient를 동기화하여 모델을 업데이트하는 방식. GIL(Global Interpreter Lock) 문제를 피하고, 통신 오버헤드를 줄여 효율적인 분산 학습을 가능하게 함.

  • DDP를 사용하기 위해서는 torch.distributed 패키지를 활용하여 프로세스 그룹을 초기화하고, 각 프로세스에 모델을 할당한 후, DistributedDataParallel 클래스로 모델을 감싸는 등의 설정이 필요함

댓글남기기