Flamingo: a Visual Language Model for Few-Shot Learning
Flamingo: a Visual Language Model for Few-Shot Learning
- google deepmind, NIPS 2022
- “Flamingo: a Visual Language Model for Few-Shot Learning”
💡 기존의 거대 모델들을 Frozen시킨 채, Gated Cross-Attention을 통해 시각 정보를 주입함으로써 학습하지 않은 새로운 visual task도 Few-shot 방식으로 가능하게 함
abstract
- 목적: 소수의 예시만으로 새로운 task에 빠르게 적응하는 능력을 가진 flamingo라는 vlm을 제시함
- 핵심 기능
- 이미 강력한 성능을 가진 vision encoder와 llm을 효과적으로 연결하는 구조
- 텍스트와 이미지/비디오가 임의로 섞인 시퀀스를 입력받아 처리할 수 있음
- 이미지/비디오를 모두 입력으로 받아들임
- 텍스트와 이미지가 혼합된 대규모 웹 코퍼스를 통해 학습 → in-context few-shot learning 능력 가짐
- 별도의 finetuning 없이 몇 개의 예제만 프롬프트로 제공해도, 기존에 많은 데이터로 학습한 모델보다 더 좋은 성능을 기록함
- vqa, 캡셔닝 등 여러 이미지/비디오 task를 하나의 모델로 수행할 수 있음
→ “기존의 llm처럼, 시각 정보가 포함된 문제도 몇 가지 예시만 보여주면 따로 재학습 시킬 필요 없이 즉각적으로 문제를 풀어내는 강력한 모델을 만들었다”
introduction
- 배경
- 하나의 짧은 instruction을 주고 새로운 task를 빠르게 배우는 능력이 지능의 핵심 요소임
- 기존 컴퓨터 비전 분야의 방법들의 한계
- fine-tuning의 비용 큼
- clip과 같은 contrastive 모델은 이미지-텍스트의 유사도를 측정할 수는 있지만, 직접 문장을 생성 x → 일부 task에서만 사용 가능하고 캡셔닝, vqa 등 개방형 task에 부적합함
- 언어를 생성하는 vlm은 데이터가 적을 때 성능이 좋지 않았음
- flamingo 제시
- llm이 가진 few-shot 학습 능력을 시각 영역으로 확장함
- 텍스트 중간중간에 이미지/비디오가 섞인 시퀀스를 입력받아서 다음에 올 텍스트를 예측하는 방식
- 사전학습한 비전 인코더와 llm을 그대로 사용함 + 비전 인코더와 llm을 연결하는 새로운 아키텍쳐를 추가해서 두 모델이 가지고 있는 지식을 보존하면서 잘 연결할 수 있도록 함
- perceiver 기반 구조: 고해상도 이미지나 비디오에서 가변적인 수의 피처가 추출되더라도, 고정된 수의 시각적 토큰으로 변환해서 효율적으로 처리함
- llm이 가진 few-shot 학습 능력을 시각 영역으로 확장함
- 학습 방법
- 정제된 annotated 데이터가 아니라 (일반 머신러닝 학습용 데이터), 텍스트-이미지가 혼합된 대규모 웹 데이터를 학습함
- 학습 후에는 별도의 tuning 없이 몇개의 예시를 보여주는 것만으로 새로운 visual task에 즉시 adapt함 (few-shot)
- contribution
- few-shot으로 다양한 멀티모달 task를 수행할 수 있는 flamingo 모델을 제시함
- 정량적으로 flamingo가 어떻게 few-shot을 통해서 여러 task에 adapt하는지 평가함
- 학습에 사용하지 않은 데이터로 few-shot 능력 측정
- 16개의 멀티모달 task에서 sota를 달성함
approach
- 핵심 구조
- perceiver resampler: 비전 인코더로부터 이미지/비디오의 시공간적 피처를 입력받아서, 고정된 개수의 visual 토큰으로 변환함
- cross-attention layers: frozen lm 레이어 사이에 삽입
- lm이 다음 토큰을 예측할 때 시각적 정보를 풍부하게 활용할 수 있도록 도움
flamingo의 수학적 모델링
- y_l: 현재 예측해야 할 l번째 언어 토큰
- y<l: 이전의 언어 토큰 집합,
- x≤l: 현재 토큰 y_l 앞에 잇는 이미지/비디오의 집합
- in-context few-shot 학습
- 텍스트/이미지(비디오)가 섞인 입력을 처리할 수 있기 때문에, gpt-3와 유사하게 몇 개의 새로운 task에 대한 예제를 보여주면 별도의 학습 없이도 새로운 task를 수행함
2.1. visual processing and the perceiver resampler
- vision encoder
- 사전학습된 NormalizerFree ResNet (NFNet-F6) 모델 사용
- image-text pair 데이터셋을 활용해서 contrastive learning 목적으로 사전학습
- 이미지 처리: 최종 output인 2d 피처 grid를 1d 시퀀스로 flatten
- 비디오 처리: 초당 1 프레임으로 샘플링해서 각 프레임을 독립적으로 인코딩 → 학습된 시간 임베딩을 추가해서 3d 시공간 피처 grid를 만듬 → 1d 시퀀스로 flatten
- perceiver resampler
- 입력 이미지/비디오의 특징 개수가 가변적이더라도, 항상 64개의 고정된 시각적 출력 토큰으로 압축함
- cross attention의 계산 복잡도를 크게 줄여줌
- 미리 정의된 개수(64개)의 latent input queries를 학습함
- 이 쿼리들은 트랜스포머에 입력되어 visual feature와 cross attention을 수행하게 됨
- query q: 미리 정의된 query 벡터
- key k, value v :비전 인코더의 output (가변적인 개수)
- → 입력이 100개의 피처든 1000개든 상관없이 늘 64개의 visual token이 생성됨
- 이 쿼리들은 트랜스포머에 입력되어 visual feature와 cross attention을 수행하게 됨
2.2. conditioning frozen language models on visual representations
- 어떻게 완성된 lm의 지식을 망가뜨리지 않으면서 시각 정보를 자연스럽게 주입할 것인가?
- 기존 트랜스포머 디코더 layer는 frozen
- gated xattn-dense 블록
- 기존 lm 계층 사이에 새롭게 삽입되는 계층으로, perceiver resampler가 만든 시각적 출력에 대해 cross attention을 수행함
- 스크래치부터 학습됨
- 새로운 layer를 추가하면 초기 학습 시 모델이 불안정해질 수 있음
- 이를 해결하기 위해 tanh-gating 방식을 도입함
- 새로 추가된 레이어의 output에 tanh(alpha)를 곱한 뒤, 이를 기존 레이어의 residual connection에 더함
- alpha는 학습 시작 시 0으로 시작 → tanh(0)=0이되어 초기에는 새로운 레이어의 영향력이 0임
- 학습 초기에는 기존의 lm과 동일하게 학습이 시작되므로, 학습의 안정성이 높아짐
- 이를 해결하기 위해 tanh-gating 방식을 도입함
- 크기별 버전
- lm의 크기에 따라 3가지 버전으로 나뉨
- lm으로 deepmind의 chinchilla 모델을 사용
- Flamingo-3B: 1.4B 파라미터 Chinchilla 기반
- Flamingo-9B: 7B 파라미터 Chinchilla 기반
- Flamingo-80B: 70B 파라미터 Chinchilla 기반 ← 뒤에서 flamingo라고 언급하면 이 모델을 뜻하는 것
- lm으로 deepmind의 chinchilla 모델을 사용
- 모델의 크기를 키울 때 gated xattn-dense 모듈만 커지고, vision encoder나 perceiver resampler는 동일한 사이즈
- lm의 크기에 따라 3가지 버전으로 나뉨
2.3. Multi-visual input support: per-image/video attention masking
- 텍스트를 생성할 때, 모델은 직전에 등장한 이미지의 시각적 토큰에만 직접적으로 cross attention함
- 모든 이전 이미지를 한번에 보는게 아니라, 현재 맥락과 가장 관련있는 이미지에 집중
- lm 내부의 self attention을 통해 이전의 모든 이미지에 대한 정보가 텍스트를 통해 간접적으로 유지되긴 함..
- 학습할 때에는 시퀀스 당 최대 5개의 이미지만 사용함
- 이 마스킹 방법 덕분에 추론 시에는 최대 32개의 이미지/비디오 쌍이 포함된 긴 시퀀스도 처리할 수 있음
- 실험적으로 모든 이전 이미지를 참조하는 것 보다 바로 직전 이미지에 집중하는 것이 더 효과적임을 확인함
2.4. Training on a mixture of vision and language datasets
- flamingo는 3가지 종류의 웹에서 가져온 데이터로 학습됨: 이미지-텍스트가 뒤섞인 데이터셋, image-text pair, video-text pair
- M3W (multimodal massiveweb): 혼합 데이터셋
- 약 4,300만 개의 웹페이지 html에서 텍스트와 이미지를 추출함
- document object model 구조를 분석하여 텍스트 대비 이미지의 정확한 위치를 파악함
- 데이터 구조화
- 이미지 위치에
태그 삽입 - 이미지 나타나기 전, 문서 끝에
(end of chunk) 특수 토큰을 삽입해서 정보의 단락을 구분함
- 이미지 위치에
- 데이터 구조화
- 연산 효율을 위해 각 문서에서 256개의 토큰을 랜덤 샘플링하고, 그 안에서 최대 5개의 이미지만 포함시킴
이미지/비디오-텍스트 쌍
데이터셋 규모 특징 ALIGN 18억 개 이미지와 대체 텍스트(alt-text) 쌍 LTIP 3억 1,200만 개 고품질의 긴 설명을 가진 이미지-텍스트 쌍 VTP 2,700만 개 평균 22초 길이의 짧은 비디오와 문장 설명 쌍 - m3w랑 동일한 형식으로 만들기 위해, 캡션 앞에
를 붙이고, 뒤에 를 추가함
- m3w랑 동일한 형식으로 만들기 위해, 캡션 앞에
학습 목표와 최적화
- D_m: m번째 데이터셋
- lambda m: 각 데이터 별 가중치 → 이게 성능을 결정하는 핵심 요소임
- 이전 텍스트와 이전 시각 정보가 주어졌을 때, 정답 l번째 토큰이 나올 확률을 최대화
- 모든 데이터셋의 gradient를 accumulate해서 업데이트함 → 데이터를 순차적으로 하나씩 학습하는 round-robin 방식보다 더 성능이 뛰어남
2.5. Task adaptation with few-shot in-context learning
- finetuning없이 어떻게 새로운 task에 적응하는가?
- 학습이 끝나면, gpt-3와 유사한 방식으로 멀티모달 프롬프트를 통해 새로운 visual task를 해결함
- (image, text) or (video, text)의 예시 pair 들 + 쿼리 이미지와 프롬프트
- 모델의 출력 방식에 따라 2가지 평가 모드 사용
- 개방형 평가 (open-ended): beam 서치 디코딩을 사용해서 텍스트를 자유롭게 생성
- 폐쇄형 평가 (close-ended): 모델의 log likelihood를 사용해서 미리 정해진 답변 후보들의 점수를 매기고 가장 높은 것을 선택함
- zero-shot 일반화 성능 평가
- 이미지가 없는 텍스트 전용 예시 2개를 프롬프트로 제공해서, 시각적 예시 없이도 모델이 작업을 수행할 수 있는 지를 측정함
### Experiments
- 벤치마크 구성
- dev 벤치마크 5개
- COCO, OKVQA, VQAv2, MSVDQA, VATEX
- 학습할 때도 DEV로 중간 평가함 - bias 있음
- 평가용 추가 벤치마크 11개
- 캡셔닝, 비디오 vqa, 시각적 대화 … 여러 task를 포함
- dev 벤치마크 5개
- few-shot 평가 방법
- 예시들을 프롬프트로 제공하고, query sample에 대해 평가함
- 모든 벤치마크에서 평가 하이퍼파라미터를 동일하게 유지함
- task 특성에 따라 4가지의 few-shot 평가 프롬프트를 사용
3.1 Few-shot learning on vision-language tasks
- flamingo는 16개의 모든 벤치마크에서 기존의 zero/few-shot sota를 큰 차이로 능가함
- finetuning을 한 sota와 비교했을 때, 비등했음
- 심지어 6개의 task에서는 finetuning한 결과보다 더 좋은 성능을 보임
gpt-3과 유사하게, 모델의 크기가 커질수록 few-shot 성능이 향상됨
- 학습할 때는 최대 이미지를 5개씩만 봤는데, 추론할 때는 최대 32개의 visual 입력으로부터 정보를 얻어 성능을 높일 수 있었음
- → 가변적인 수의 시각적 입력을 처리하는 flamingo 아키텍쳐의 유연성을 보여줌
3.2 Fine-tuning Flamingo as a pretrained vision-language model
- flamingo 모델을 각 task에 finetuning했을 때 성능 변화를 확인함
- short schedule, small learning rate
- vision encoder도 함께 학습
- 5개의 벤치마크에서 새로운 sota를 달성함 / few-shot 학습보다 더 좋은 성능
3.3 ablation study
- 학습 데이터
- 3가지 종류의 데이터를 적절히 섞는 것이 중요함
- gradient accumulation을 사용하는 것이 성능 향상
- tanh gating
- gated xattn dense
- cross attention frequency
- 모든 lm층에서 새로운 블록을 추가하는 것이 성능 가장 좋
- perceiver resampler
- vision encoder
- lm: 처음부터 학습시키거나 vs fine-tuning
- lm은 freeze하는 것이 베스트
Discussion
- 한계
- 사전학습된 lm을 기반으로 하기 때문에 환각 등 lm의 문제점을 그대로 물려받음
- clip 등 contrastive 모델은 텍스트-이미지 검색에 최적화 / 분류 성능 좋은데, flmaingo는 이에 비해 분류 능력이 떨어짐
- in context의 예시를 어떻게 구성하느냐에 따라 성능이 달라짐, 데이터가 많아질 수록 in context의 연산 비용이 급증함
- 결론
- 최소한의 데이터로 이미지/비디오 task를 수행할 수 있는 범용 vlm
- 전통적인 비전 벤치마크를 넘어서 사용자와 대화가 가능함
- 이는 visual encoder와 llm을 연결하는 것이 범용적인 visual 이해로 가는 중요한 단계임을 시사함
This post is licensed under CC BY 4.0 by the author.









