Post

Flamingo: a Visual Language Model for Few-Shot Learning

Flamingo: a Visual Language Model for Few-Shot Learning
  • google deepmind, NIPS 2022
    • “Flamingo: a Visual Language Model for Few-Shot Learning”

💡 기존의 거대 모델들을 Frozen시킨 채, Gated Cross-Attention을 통해 시각 정보를 주입함으로써 학습하지 않은 새로운 visual task도 Few-shot 방식으로 가능하게 함

abstract

  • 목적: 소수의 예시만으로 새로운 task에 빠르게 적응하는 능력을 가진 flamingo라는 vlm을 제시함
  • 핵심 기능
    • 이미 강력한 성능을 가진 vision encoder와 llm을 효과적으로 연결하는 구조
    • 텍스트와 이미지/비디오가 임의로 섞인 시퀀스를 입력받아 처리할 수 있음
    • 이미지/비디오를 모두 입력으로 받아들임
  • 텍스트와 이미지가 혼합된 대규모 웹 코퍼스를 통해 학습 → in-context few-shot learning 능력 가짐
  • 별도의 finetuning 없이 몇 개의 예제만 프롬프트로 제공해도, 기존에 많은 데이터로 학습한 모델보다 더 좋은 성능을 기록함
  • vqa, 캡셔닝 등 여러 이미지/비디오 task를 하나의 모델로 수행할 수 있음

→ “기존의 llm처럼, 시각 정보가 포함된 문제도 몇 가지 예시만 보여주면 따로 재학습 시킬 필요 없이 즉각적으로 문제를 풀어내는 강력한 모델을 만들었다”

introduction

  • 배경
    • 하나의 짧은 instruction을 주고 새로운 task를 빠르게 배우는 능력이 지능의 핵심 요소임
    • 기존 컴퓨터 비전 분야의 방법들의 한계
      • fine-tuning의 비용 큼
      • clip과 같은 contrastive 모델은 이미지-텍스트의 유사도를 측정할 수는 있지만, 직접 문장을 생성 x → 일부 task에서만 사용 가능하고 캡셔닝, vqa 등 개방형 task에 부적합함
      • 언어를 생성하는 vlm은 데이터가 적을 때 성능이 좋지 않았음
  • flamingo 제시
    • llm이 가진 few-shot 학습 능력을 시각 영역으로 확장함
      • 텍스트 중간중간에 이미지/비디오가 섞인 시퀀스를 입력받아서 다음에 올 텍스트를 예측하는 방식
      • 사전학습한 비전 인코더와 llm을 그대로 사용함 + 비전 인코더와 llm을 연결하는 새로운 아키텍쳐를 추가해서 두 모델이 가지고 있는 지식을 보존하면서 잘 연결할 수 있도록 함
      • perceiver 기반 구조: 고해상도 이미지나 비디오에서 가변적인 수의 피처가 추출되더라도, 고정된 수의 시각적 토큰으로 변환해서 효율적으로 처리함
  • 학습 방법
    • 정제된 annotated 데이터가 아니라 (일반 머신러닝 학습용 데이터), 텍스트-이미지가 혼합된 대규모 웹 데이터를 학습함
    • 학습 후에는 별도의 tuning 없이 몇개의 예시를 보여주는 것만으로 새로운 visual task에 즉시 adapt함 (few-shot)
  • contribution
    1. few-shot으로 다양한 멀티모달 task를 수행할 수 있는 flamingo 모델을 제시함
    2. 정량적으로 flamingo가 어떻게 few-shot을 통해서 여러 task에 adapt하는지 평가함
      • 학습에 사용하지 않은 데이터로 few-shot 능력 측정
    3. 16개의 멀티모달 task에서 sota를 달성함
      • 그 중 6개 task에서는 finetuning한 모델보다 훨씬 적은 데이터를 사용하고도 더 좋은 성능을 기록함
      • 32개의 shot을 주었고, task-specfic 학습 데이터를 1000배 덜 사용하였음
      • 충분한 데이터가 주어진 경우, vqav2, vatex 등 5개의 주요 벤치마크에서도 새로운 sota 기록

        image.png

image.png

approach

  • 핵심 구조
    • perceiver resampler: 비전 인코더로부터 이미지/비디오의 시공간적 피처를 입력받아서, 고정된 개수의 visual 토큰으로 변환함
    • cross-attention layers: frozen lm 레이어 사이에 삽입
      • lm이 다음 토큰을 예측할 때 시각적 정보를 풍부하게 활용할 수 있도록 도움
  • flamingo의 수학적 모델링

    image.png

    • y_l: 현재 예측해야 할 l번째 언어 토큰
    • y<l: 이전의 언어 토큰 집합,
    • x≤l: 현재 토큰 y_l 앞에 잇는 이미지/비디오의 집합
  • in-context few-shot 학습
    • 텍스트/이미지(비디오)가 섞인 입력을 처리할 수 있기 때문에, gpt-3와 유사하게 몇 개의 새로운 task에 대한 예제를 보여주면 별도의 학습 없이도 새로운 task를 수행함

image.png

2.1. visual processing and the perceiver resampler

  • vision encoder
    • 사전학습된 NormalizerFree ResNet (NFNet-F6) 모델 사용
    • image-text pair 데이터셋을 활용해서 contrastive learning 목적으로 사전학습
    • 이미지 처리: 최종 output인 2d 피처 grid를 1d 시퀀스로 flatten
    • 비디오 처리: 초당 1 프레임으로 샘플링해서 각 프레임을 독립적으로 인코딩 → 학습된 시간 임베딩을 추가해서 3d 시공간 피처 grid를 만듬 → 1d 시퀀스로 flatten
  • perceiver resampler
    • 입력 이미지/비디오의 특징 개수가 가변적이더라도, 항상 64개의 고정된 시각적 출력 토큰으로 압축함
    • cross attention의 계산 복잡도를 크게 줄여줌
    • 미리 정의된 개수(64개)의 latent input queries를 학습함
      • 이 쿼리들은 트랜스포머에 입력되어 visual feature와 cross attention을 수행하게 됨
        • query q: 미리 정의된 query 벡터
        • key k, value v :비전 인코더의 output (가변적인 개수)
        • → 입력이 100개의 피처든 1000개든 상관없이 늘 64개의 visual token이 생성됨

2.2. conditioning frozen language models on visual representations

image.png

  • 어떻게 완성된 lm의 지식을 망가뜨리지 않으면서 시각 정보를 자연스럽게 주입할 것인가?
  • 기존 트랜스포머 디코더 layer는 frozen
  • gated xattn-dense 블록
    • 기존 lm 계층 사이에 새롭게 삽입되는 계층으로, perceiver resampler가 만든 시각적 출력에 대해 cross attention을 수행함
    • 스크래치부터 학습됨
  • 새로운 layer를 추가하면 초기 학습 시 모델이 불안정해질 수 있음
    • 이를 해결하기 위해 tanh-gating 방식을 도입함
      • 새로 추가된 레이어의 output에 tanh(alpha)를 곱한 뒤, 이를 기존 레이어의 residual connection에 더함
      • alpha는 학습 시작 시 0으로 시작 → tanh(0)=0이되어 초기에는 새로운 레이어의 영향력이 0임
      • 학습 초기에는 기존의 lm과 동일하게 학습이 시작되므로, 학습의 안정성이 높아짐
  • 크기별 버전
    • lm의 크기에 따라 3가지 버전으로 나뉨
      • lm으로 deepmind의 chinchilla 모델을 사용
        • Flamingo-3B: 1.4B 파라미터 Chinchilla 기반
        • Flamingo-9B: 7B 파라미터 Chinchilla 기반
        • Flamingo-80B: 70B 파라미터 Chinchilla 기반 ← 뒤에서 flamingo라고 언급하면 이 모델을 뜻하는 것
    • 모델의 크기를 키울 때 gated xattn-dense 모듈만 커지고, vision encoder나 perceiver resampler는 동일한 사이즈

2.3. Multi-visual input support: per-image/video attention masking

  • 텍스트를 생성할 때, 모델은 직전에 등장한 이미지의 시각적 토큰에만 직접적으로 cross attention함
    • 모든 이전 이미지를 한번에 보는게 아니라, 현재 맥락과 가장 관련있는 이미지에 집중
  • lm 내부의 self attention을 통해 이전의 모든 이미지에 대한 정보가 텍스트를 통해 간접적으로 유지되긴 함..
  • 학습할 때에는 시퀀스 당 최대 5개의 이미지만 사용함
  • 이 마스킹 방법 덕분에 추론 시에는 최대 32개의 이미지/비디오 쌍이 포함된 긴 시퀀스도 처리할 수 있음
  • 실험적으로 모든 이전 이미지를 참조하는 것 보다 바로 직전 이미지에 집중하는 것이 더 효과적임을 확인함

2.4. Training on a mixture of vision and language datasets

  • flamingo는 3가지 종류의 웹에서 가져온 데이터로 학습됨: 이미지-텍스트가 뒤섞인 데이터셋, image-text pair, video-text pair
  • M3W (multimodal massiveweb): 혼합 데이터셋
    • 약 4,300만 개의 웹페이지 html에서 텍스트와 이미지를 추출함
    • document object model 구조를 분석하여 텍스트 대비 이미지의 정확한 위치를 파악함
      • 데이터 구조화
        • 이미지 위치에 태그 삽입
        • 이미지 나타나기 전, 문서 끝에 (end of chunk) 특수 토큰을 삽입해서 정보의 단락을 구분함
    • 연산 효율을 위해 각 문서에서 256개의 토큰을 랜덤 샘플링하고, 그 안에서 최대 5개의 이미지만 포함시킴
  • 이미지/비디오-텍스트 쌍

    데이터셋규모특징
    ALIGN18억 개이미지와 대체 텍스트(alt-text) 쌍
    LTIP3억 1,200만 개고품질의 긴 설명을 가진 이미지-텍스트 쌍
    VTP2,700만 개평균 22초 길이의 짧은 비디오와 문장 설명 쌍
    • m3w랑 동일한 형식으로 만들기 위해, 캡션 앞에 를 붙이고, 뒤에 를 추가함
  • 학습 목표와 최적화

    image.png

    • D_m: m번째 데이터셋
    • lambda m: 각 데이터 별 가중치 → 이게 성능을 결정하는 핵심 요소임
    • 이전 텍스트와 이전 시각 정보가 주어졌을 때, 정답 l번째 토큰이 나올 확률을 최대화
    • 모든 데이터셋의 gradient를 accumulate해서 업데이트함 → 데이터를 순차적으로 하나씩 학습하는 round-robin 방식보다 더 성능이 뛰어남

2.5. Task adaptation with few-shot in-context learning

  • finetuning없이 어떻게 새로운 task에 적응하는가?
  • 학습이 끝나면, gpt-3와 유사한 방식으로 멀티모달 프롬프트를 통해 새로운 visual task를 해결함
    • (image, text) or (video, text)의 예시 pair 들 + 쿼리 이미지와 프롬프트
  • 모델의 출력 방식에 따라 2가지 평가 모드 사용
    • 개방형 평가 (open-ended): beam 서치 디코딩을 사용해서 텍스트를 자유롭게 생성
    • 폐쇄형 평가 (close-ended): 모델의 log likelihood를 사용해서 미리 정해진 답변 후보들의 점수를 매기고 가장 높은 것을 선택함
  • zero-shot 일반화 성능 평가
    • 이미지가 없는 텍스트 전용 예시 2개를 프롬프트로 제공해서, 시각적 예시 없이도 모델이 작업을 수행할 수 있는 지를 측정함

    ### Experiments

    • 벤치마크 구성
      • dev 벤치마크 5개
        • COCO, OKVQA, VQAv2, MSVDQA, VATEX
        • 학습할 때도 DEV로 중간 평가함 - bias 있음
      • 평가용 추가 벤치마크 11개
        • 캡셔닝, 비디오 vqa, 시각적 대화 … 여러 task를 포함
    • few-shot 평가 방법
      • 예시들을 프롬프트로 제공하고, query sample에 대해 평가함
      • 모든 벤치마크에서 평가 하이퍼파라미터를 동일하게 유지함
      • task 특성에 따라 4가지의 few-shot 평가 프롬프트를 사용

image.png

3.1 Few-shot learning on vision-language tasks

  • flamingo는 16개의 모든 벤치마크에서 기존의 zero/few-shot sota를 큰 차이로 능가함
  • finetuning을 한 sota와 비교했을 때, 비등했음
    • 심지어 6개의 task에서는 finetuning한 결과보다 더 좋은 성능을 보임
  • gpt-3과 유사하게, 모델의 크기가 커질수록 few-shot 성능이 향상

    image.png

  • 학습할 때는 최대 이미지를 5개씩만 봤는데, 추론할 때는 최대 32개의 visual 입력으로부터 정보를 얻어 성능을 높일 수 있었음
    • 가변적인 수의 시각적 입력을 처리하는 flamingo 아키텍쳐의 유연성을 보여줌

3.2 Fine-tuning Flamingo as a pretrained vision-language model

  • flamingo 모델을 각 task에 finetuning했을 때 성능 변화를 확인함
  • short schedule, small learning rate
  • vision encoder도 함께 학습
  • 5개의 벤치마크에서 새로운 sota를 달성함 / few-shot 학습보다 더 좋은 성능

image.png

3.3 ablation study

image.png

  • 학습 데이터
    • 3가지 종류의 데이터를 적절히 섞는 것이 중요함
  • gradient accumulation을 사용하는 것이 성능 향상
  • tanh gating
  • gated xattn dense
  • cross attention frequency
    • 모든 lm층에서 새로운 블록을 추가하는 것이 성능 가장 좋
  • perceiver resampler
  • vision encoder
  • lm: 처음부터 학습시키거나 vs fine-tuning
    • lm은 freeze하는 것이 베스트

Discussion

  • 한계
    • 사전학습된 lm을 기반으로 하기 때문에 환각 등 lm의 문제점을 그대로 물려받음
    • clip 등 contrastive 모델은 텍스트-이미지 검색에 최적화 / 분류 성능 좋은데, flmaingo는 이에 비해 분류 능력이 떨어짐
    • in context의 예시를 어떻게 구성하느냐에 따라 성능이 달라짐, 데이터가 많아질 수록 in context의 연산 비용이 급증함
  • 결론
    • 최소한의 데이터로 이미지/비디오 task를 수행할 수 있는 범용 vlm
    • 전통적인 비전 벤치마크를 넘어서 사용자와 대화가 가능함
    • 이는 visual encoder와 llm을 연결하는 것이 범용적인 visual 이해로 가는 중요한 단계임을 시사함
This post is licensed under CC BY 4.0 by the author.