5 분 소요

지난 주 블로글에서 셀러브래스(Cerebras)와 그록(Groq) 회사가 Llama API를 통해 더 빠른 추론 속도를 서비스를 제공하겠다는 라마콘 소식을 올렸는 데, 오늘은 UC 버클리와 UCSF 연구진이 공동으로 발표한 병렬로 LLM 추론하는 APR에 대한 논문을 읽고 흥미로워서 오랜만에 논문을 요점정리하는 블로그를 진행하고자 한다.

1. 논문이 나오게 된 배경

  • 대규모 언어 모델(LLM)은 OpenAI의 o1이나 DeepSeekR1처럼 Test-time compute을 활용한 검색 및 강화학습을 통해 성능을 최적화하듯 추론 영역에서도 큰 진보를 이끌어 냄.
  • 직렬 체인 오브 싱킹(Serialized Chain-of-Thought) 방식은 출력 시퀀스가 지나치게 길어져 지연(latency)이 증가하고, 컨텍스트 윈도우 한계를 초과하는 문제를 유발함.
  • Best-of-N, Self-Consistency와 같은 병렬 추론 기법은 각 추론 경로 간의 조정이 부족하고, 엔드 투 엔드 최적화가 이뤄지지 않아 계산 효율성이 떨어지며 성능 향상에도 제한함.
  • Tree-of-Thought와 같은 구조화된 추론 기법은 사람이 설계한 검색 구조에 의존하므로, 다양한 추론 작업이나 도메인에 걸쳐 확장성과 유연성이 떨어짐.

2. LLM 추론의 계산 문제 해결 방안의 문제점

  • 추론 시점에서 계산량을 늘려 성능을 높이는 방법: 출력 시퀀스를 지나치게 길게 만들어 지연을 초래하고, 모든 추론 경로를 하나의 컨텍스트 윈도우에 넣어야 하는 부담을 줌.
  • 앙상블 기반의 병렬화 전략: 여러 모델 호출을 동시에 실행하지만, 경로 간의 조정이 미흡해 중복 계산과 비효율적 자원 활용
  • Tree-of-Thought나 멀티 에이전트 방식: 병렬 추론 구조를 고정적으로 설계하지만, 확장성과 유연성에 한계
  • PASTA 접근 방식: 작업을 병렬 하위 작업으로 나누지만, 결국 전체 컨텍스트를 메인 추론 경로로 재통합함으로써 컨텍스트 사용량 감소에 실패함.
  • Hogwild! 추론: 병렬 워커 스레드를 활용하지만, 단순 프롬프트 기반이며 엔드 투 엔드 최적화는 적용되지 않음

3. 적응형 병렬 추론(APR)의 탄생

  • 적응형 병렬 추론(APR, Adaptive Parallel Reasoning)
    • 언어 모델이 추론 시점에서 직렬 연산과 병렬 연산을 동적으로 분배할 수 있도록 설계됨.
    • 기존의 직렬 체인 오브 싱킹, 병렬 자기 일관성(Self-Consistency) 추론, 구조화된 검색 등을 일반화하여, 모델이 언제, 어떻게 병렬화할지 스스로 결정하도록 훈련
  • 부모-자식 스레딩 메커니즘
    • 부모 추론 스레드가 spawn() 연산을 통해 여러 자식 스레드에 하위 작업을 병렬로 분배
    • 각 자식 스레드는 결과를 join() 연산을 통해 부모에게 반환함으로써 추론을 이어감
    • SGLang 모델 서빙 프레임워크 위에 구축되어, 자식 스레드에서의 병렬 추론을 배치 처리로 효율적으로 실행하여 실시간 지연을 대폭 줄임.
  • 엔드 투 엔드 강화학습 기반의 파인튜닝
    • 미리 정의된 추론 구조 없이, 작업 성공률을 극대화하기 위해 전체 추론 과정을 최적화함

4. 적응형 병렬 추론 장점

  • 고정된 컨텍스트 윈도우 내에서 더 높은 성능을 달성
  • 계산 자원이 늘어날수록 성능이 안정적으로 확장
  • 동일한 지연 시간에서 기존 방식보다 우수한 추론 결과

5. 적응형 병렬 추론 아키텍처

그림 1 - 직렬 탐색 방식(Gandhi 외, 2024)(위) vs. 적응형 병렬 추론 방식 (아래)

[그림 1 - 직렬 탐색 방식 (Gandhi 외, 2024)(위) vs. 적응형 병렬 추론 방식 (아래)]
  • 언어 모델이 병렬 추론 과정을 동적으로 조정할 수 있도록 하는 정교한 멀티스레딩 메커니즘을 구현
  • 직렬적 추론 방식의 한계를 극복하고, 부모-자식 스레드 간에 계산을 분산시켜 지연(latency)을 최소화하고, 컨텍스트 제약 내에서 성능을 극대화
  • 단순한 구조화 추론을 넘어, LLM의 병렬 추론을 학습 가능한 정책으로 확장함으로써 고속, 고성능, 저비용 추론을 실현할 수 있도록 설계되었음.

그림 2 - 적응형 병렬 추론(APR) 개요

[그림 2 - 적응형 병렬 추론(APR) 개요]
  • 멀티스레딩 추론 시스템
    • 부모 스레드는 spawn(msgs) 연산을 통해 여러 자식 스레드를 생성함.
    • 각 자식 스레드는 고유한 컨텍스트를 받아 독립적으로 추론을 수행하며, 동일한 언어 모델을 동시에 공유해 병렬 실행함.
    • 자식 스레드가 작업을 완료하면 join(msg) 연산을 통해 부모 스레드에 결과를 반환하는데, 이때 가장 중요한 정보만 선택적으로 전달함
    • 중간 추론 경로(search trace)를 자식 스레드 내부에 국한시켜 불필요한 토큰 사용을 대폭 줄이는 효과가 있음.
  • 훈련 방식은 2단계 접근법
    • 자동 생성된 데모(demonstration)를 활용한 지도 학습을 수행함.
    • 이 데모는 깊이 우선(depth-first)너비 우선(breadth-first) 탐색 전략을 혼합한 하이브리드 탐색 패턴을 포함
    • 기호 기반 솔버(symbolic solver)는 병렬화를 포함한 데모를 생성하며, 검색 과정을 여러 하위 구성으로 분해해 훈련 및 추론 중 컨텍스트 윈도우 병목 현상을 방지함
  • GRPO(Gradient-based Policy Optimization)를 이용한 엔드 투 엔드 강화학습
    • 모델은 언제, 얼마나 넓게 자식 스레드를 생성할지 전략적으로 판단하는 법을 학습함.
    • 모델은 반복적으로 추론 경로를 샘플링하고, 그 정확도를 평가하며, 그에 따라 파라미터를 조정함
    • 결과적으로 모델은 병렬 탐색과 컨텍스트 제약 간의 균형을 최적화하여, 계산 효율성과 추론 성능을 모두 향상시킴

6. 적응형 병렬 추론의 성능 평가 및 실험 결과

그림 3 - Countdown 과제를 해결하는 APR 기법의 예시 추론 경로

[그림 3 - Countdown 과제를 해결하는 APR 기법의 예시 추론 경로]
  • Llama2 아키텍처 기반의 2억 2천 8백만(228M) 파라미터를 가진 표준 디코더 전용 언어 모델을 사용하여, 직렬 체인 오브 싱킹(serialized chain-of-thought) 및 Self-Consistency 방법과 비교 평가함.
  • 4,096 토큰의 컨텍스트 윈도우를 지원하며, 모든 모델은 기호 기반 솔버(symbolic solver)로부터 생성된 50만 개의 추론 경로를 사용한 지도학습으로 초기화되었음
  • 조건부 계산 예산(budget constraint) 전략 : SoS+ 모델에는 컨텍스트 윈도우 조건(context-window conditioning), APR 모델에는 스레드 수 조건(thread count conditioning)
  • 추론은 SGLang 프레임워크를 기반으로 진행되었으며, 이 프레임워크는 지속적 배치 처리(continuous batching)Radix Attention을 지원하여 APR의 효율적 구현했음

그림 4 - SOS+와 APR의 확장 스케일링 비교

[그림 4 - SOS+와 APR의 확장 스케일링 비교]
  • 실험 결과 요약
    • APR은 전반적으로 직렬 방식보다 지속적으로 우수한 성능을 보임.
    • 계산량이 낮은 초기 구간에서는 병렬화 오버헤드로 인해 APR의 성능이 다소 떨어지지만, 계산량이 증가할수록 성능 격차가 급격히 확대됨.
      • 예: 토큰 20,000 사용 시 APR은 SoS+ 대비 13.5% 향상, 계산량은 57.4% 더 적게 사용하면서 SoS+의 pass@8 정확도를 초과함.
    • 컨텍스트 윈도우 확장 시, APR은 추론을 병렬 스레드로 분산시킴으로써 컨텍스트 사용 효율이 매우 높음.
      • 예: 스레드 10개 사용 시 4k 토큰 제한 내에서 약 20% 더 높은 정확도 달성. 이는 하나의 긴 추론 경로를 컨텍스트에 모두 포함시키는 기존 방식보다 효율적임을 시사.

그림 5 - 강화학습(RL) 전후의 모델 성능 및 통계 비교

[그림 5 - 강화학습(RL) 전후의 모델 성능 및 통계 비교]
  • 강화학습(RL)에 의한 최적화 효과
    • 엔드 투 엔드 강화학습(GRPO)을 통해 APR 모델의 정확도는 75.5% → 83.4%로 크게 향상됨.
    • RL로 최적화된 모델은 기존 모델보다 출력 시퀀스 길이: 22.1% 증가 및 자식 스레드 수: 34.4% 증가
    • 이로 인해 “Countdown” 같은 과제에서는 깊은 탐색보다는 넓은 탐색 패턴을 선호하게 됨.
      • 이는 모델이 자율적으로 최적의 검색 전략(search strategy)을 학습할 수 있음을 보여줌.

그림 6 - APR과 SOS+ 간 효율성 성능 비교

[그림 6 - APR과 SOS+ 간 효율성 성능 비교]
  • 이론적 및 실전 효율성 모두 입증
    • 순차 토큰 사용량 측정 결과:
      • APR은 2,048 토큰을 넘기지 않고도 정확도를 크게 향상시킴.
      • 대부분의 경우 2,500 토큰을 넘기지 않음.
      • 반면 SoS+는 3,000 토큰에 가까운 상황에서도 성능 향상은 미미함.
    • 실제 지연(latency) 테스트:
      • NVIDIA RTX A6000 GPU 8개를 사용한 서버 환경에서 측정.
      • APR은 샘플당 5,000ms 지연 시간에서 75% 정확도 달성,
        • 이는 SoS+의 57% 대비 18%p 향상된 수치.
      • APR이 하드웨어 병렬화 성능을 효과적으로 활용함을 보여줌.

그림 7 - 강화학습(RL)은 SOS+(2.7%)보다 APR(7.9%)에 더 큰 성능 향상을 제공

[그림 7 - 강화학습(RL)은 SOS+(2.7%)보다 APR(7.9%)에 더 큰 성능 향상을 제공]

7. 결론

  • 언어 모델의 추론 능력을 한 단계 진보시킨 기술로, 동적 부모-자식 스레딩 구조를 통해 계산을 직렬과 병렬 경로로 유연하게 분산시킴.
  • 지도 학습과 강화 학습을 결합함으로써 사람이 설계한 추론 구조 없이도, 모델이 스스로 최적의 병렬화 전략을 개발할 수 있게 함.
  • Countdown 과제를 통한 실험 결과: 제한된 컨텍스트 내에서의 높은 성능, 계산량 증가에 따른 우수한 확장성, 동일한 지연 조건에서도 기존 방식 대비 높은 성공률
  • 이러한 성과는 복잡한 문제 해결 과제에서 추론 과정을 동적으로 구조화하는 시스템이 얼마나 효율적이고 확장 가능할 수 있는지를 보여주는 대표 사례임.

참고: https://www.marktechpost.com/2025/05/02/llms-can-now-reason-in-parallel-uc-berkeley-and-ucsf-researchers-introduce-adaptive-parallel-reasoning-to-scale-inference-efficiently-without-exceeding-context-windows/

댓글남기기