Fine-tuning CLIP for MS-COCO Retrieval

In this section, we present an example usage and some empirical guides of fine-tuning CLIP for image-text retrieval. We aim to improve the retrieval performance based on the strong zero-shot retrieval ability (see our evaluation report) of CLIP by fine-tuning CLIP on MS COCO Captions training set (118k images) with the InfoNCE loss. Contents and key findings of this section are listed as follows:

  • Fine-tuning CLIP on MS COCO training set improves the retrieval mean recall by +15% compared to raw zero-shot retrieval.

  • Proper hyper-parameters can bring at least +1% improvement.

  • Scale up batch size by partially freeze CLIP weights brings +1% improvement.

  • Compared to the zero-shot retrieval mean recall=58.39% of RN50 CLIP, at last we achieve 76.02% mean recall (17.63% improvement) by fine-tuning it with a 8x2080ti machine.

Getting Started: Naive Fine-tuning Baseline

First, assume that you have already created an environment with required dependencies, prepared csv datasets for pre-training and downstream evaluations. Then you can activate the environment and modify the PYTHONPATH variable, such that modules can be imported successfully.

conda activate ITRA
cd path/to/ITRA/
export PYTHONPATH="$PYTHONPATH:$PWD/itra"

Then we can start to fine-tune a CLIP on MS-COCO captions 2017 training set (118k images). The results should be compared with the paper-with-code leaderboard. Our baseline setting are listed as follows, we use a single-node machine with 8 NVIDIA GeForce 2080ti GPUs for training, one training epoch takes about 3.5 minutes.

  • backbone: ResNet50

  • batch_size: 32x8=256

  • dataset_size: 118287

  • epochs: 10

  • lr: 1e-05

  • opt: adamw

  • use_bn_sync: False

  • warmup: 100

  • weight_decay: 0.5

Training Command
torchrun --nproc_per_node 8 -m training.main \
    --train-data 'mscoco_captions' --retrieval-data 'mscoco_captions' \
    --retrieval-frequency 1 --datasets-dir '/data/Datasets' \
    --epochs 10 --save-frequency 0 --batch-size 32 --workers 2 \
    --lr 1e-5 --warmup 100 --weight_decay 0.5 --max-grad-norm 5 \
    --image-model 'RN50' --image-model-builder 'openclip' --text-model 'RN50' --text-model-builder 'openclip'\
    --pretrained-image-model --pretrained-text-model \
    --loss 'InfoNCE' \
    --report-to tensorboard --logs 'logs/MSCOCO-RN50'  --name '10ep-bs256-lr1e-5-wd0.5'

Under this configuration, fine-tuning significantly improves the retrieval performance (58.39→73.98, +15.59).

Type Model # Params (M) I2T R@1 I2T R@5I I2T R@10 T2I R@1 T2I R@5 T2I R@10 Mean Recall
Two-stream Zero-shot CLIP RN50 102.01 48.06 73.88 83.02 28.31 52.96 64.1 58.39
Two-stream 👉 Fine-tuned CLIP RN50 102.01 64.84 86.62 92.3 44.99 72.76 82.34 73.98
Two-stream FLIP (ViT-L-14) 427.94 78.9 94.4 97.4 61.2 84.3 90.6 84.5
Two-stream Florence (CoSwin-H) 637 81.8 95.2 63.2 85.7
Single-stream BLIP (large) 220 80.6 95.2 97.6 63.1 85.3 91.1 85.5
Single-stream PTP-BLIP (large) 220 84.2 79.3 98.8 68.8 89.5 94.2 88.8

Note

👆 Here Florence and PTP-BLIP are respectively the two-stream and single-stream SoTA retrieval methods at paper-with-code leaderboard by 2022.12.


Tuning Hyper-parameters

1. Learning Rate. We vary the learning rate from 5e-6 to 1e-4, and find that 1e-5 and 2e-5 are good for a batch size of 256. This results confirms the observations in this paper, where the authors showed that good ImageNet fine-tuning of CILP ViT-B-16 needs a quite small learning rate (2e-5 and 3e-5 for a batch size of 2048).

Learning Rate lr5e-6 lr1e-5 lr2e-5 lr3e-5 lr5e-5 lr1e-4
Mean Recall 72.91 **73.98 ** **73.97 ** 73.32 72.46 69.34

2. Weight Decay. The author of SLIP paper observed that a larger weight decay (0.5) is beneficial for CLIP. Our experiments also showed that CLIP can also handle a very large value of weight decay (i.e., 2.50). Here the training data have 118k samples, and we believe that such property can further benefit CLIP fine-tuning when the data is limited. Our results, as shown in the following table, show that CLIP is pretty robust to weight decay changes: when vary the value from 0.01 to 2.50, the performance changes in a range of only +- 0.43.

Weight Decay 2.50 2.25 2.00 1.75 1.50 1.25 1.00 0.75 0.50 0.10 0.05 0.01
Mean Recall 74.07 73.94 73.87 73.84 73.94 73.94 74.05 73.87 73.98 73.93 73.64 73.79

3. Training Length. Similar to the experiments in FLIP, our experiments showed that scaling training epochs cannot lead to further performance improvement. Only 5 or 10 epochs are not sufficient, but 15-20 epochs seems already reached the saturation.

Epochs 5 10 15 20 30
Learning Rate=1e-5 72.66 73.98 74.43 74.45 73.96
Learning Rate=2e-5 72.86 73.97 74.28 74.02 74.03

4. Batch Size. It is well known that batch size has a crucial impact for contrastive learning methods. We confirm this point by varying batch size from 32 to 800 (the maximum allowed batch size for ResNet-50 CLIP on a 8x2080ti machine) while changing learning rate according to liner scaling rule. It shows that scaling down batch size leads to significant performance drop:

BatchSize 800 512 256 128 64 32
Learning Rate 3.125E-05 2.00E-05 1.00E-05 5.00E-06 2.50E-06 1.25E-06
Mean Recall 74.89 74.85 73.98 72.14 69.24 65.04

5. ✨ Improved Naive Baseline with Better Hyper-parameters. Combining all the above hyper-parameter sweep observations together, we increase the mean recall of naive fine-tuning baseline from 73.98 to 75.04.

Baseline Hyper-parameters ✨ Improved Hyper-parameters
backbone: ResNet50 ResNet50
batch_size: 32x8=256 100x8=800
epochs: 10 15
lr: 1e-05 3.125e-05
weight_decay: 0.5 1.0
Training Command
torchrun --nproc_per_node 8 -m training.main \
    --train-data 'mscoco_captions' --retrieval-data 'mscoco_captions' \
    --retrieval-frequency 1 --datasets-dir '/data/Datasets' \
    --epochs 15 --save-frequency 0 --batch-size 100 --workers 2 \
    --lr 3125e-8 --warmup 100 --weight_decay 1.0 --max-grad-norm 5 \
    --image-model 'RN50' --image-model-builder 'openclip' --text-model 'RN50' --text-model-builder 'openclip'\
    --pretrained-image-model --pretrained-text-model \
    --loss 'InfoNCE' \
    --report-to tensorboard --logs 'logs/MSCOCO-RN50'  --name '15ep-bs800-lr3125e-8-wd1.0'

Results:

Model I2T R@1 I2T R@5I I2T R@10 T2I R@1 T2I R@5 T2I R@10 Mean Recall
Baseline 64.84 86.62 92.30 44.99 72.76 82.34 73.98
Improved Baseline 65.34 87.44 92.84 46.70 74.45 83.47 75.04

Scaling up Batch Size by Partially Freeze Weights

Fine-tuning Streategy Image Params Text Params Total Trainable Params (M) % I2T R@1 I2T R@5I I2TR@10 T2I R@1 T2I R@5 T2I R@10 Mean Recall
zero-shot evaluation - - 0 0.0% 48.06 73.88 83.02 28.31 52.96 64.10 58.39
lock CLIP and add linear projection heads linear head linear head 2.1 2.1% 47.24 75.06 84.82 32.91 61.21 72.83 62.34
lock CLIP and add MLP projection heads MLP head MLP head 16.79 16.5% 53.12 79.86 87.76 37.46 65.63 76.41 66.71
lock image tune text - All 63.69 62.4% 62.12 85.12 91.46 42.52 70.34 80.31 71.98
lock text tune image All - 38.32 37.6% 59.78 84.10 90.86 43.57 71.02 80.76 71.68
naïve fine-tuning (improved baseline) All All 102.01 100.0% 65.34 87.44 92.84 46.70 74.45 83.47 75.04
  • lock image and partially fine-tune text

Text Params projection+ln_final 11 10,11 8~11 6~11 4~11 2~11 0~11 All
Total Trainable Params (M) 0.53 3.68 6.83 13.13 19.44 25.74 32.05 38.35 63.69
% 0.5% 3.6% 6.7% 12.9% 19.1% 25.2% 31.4% 37.6% 62.4%
Mean Recall 67.23 69.15 70.27 71.36 71.79 71.96 72.00 72.18 71.98
  • lock text and partially fine-tune image

Image Params attnpool attnpool,layer4 attnpool,layer4,3 attnpool,layer4,3,2 attnpool,layer4,3,2,1 All
Text Params - - - - - -
Total Trainable Params (M) 14.79 29.75 36.85 38.07 38.29 38.32
% 14.5% 29.2% 36.1% 37.3% 37.5% 37.6%
Mean Recall 71.33 72.49 72.23 71.89 71.82 71.68
  • Scale up batchsize

Fine-tuning Streategy Image Params Text Params Total Trainable Params (M) % I2T R@1 I2T R@5I I2TR@10 T2I R@1 T2I R@5 T2I R@10 Mean Recall
naïve fine-tuning (improved baseline) bs-800-lr3.125e-5 All All 102.01 100.0% 65.34 87.44 92.84 46.70 74.45 83.47 75.04
bs800-lr3.125e-5 attnpool,layer4 0~11 68.11 66.8% 66.10 87.60 93.56 47.61 75.17 84.18 75.70
bs1792-lr7e-5 attnpool,layer4 0~11 68.11 66.8% 65.95 88.30 93.66 48.08 75.71 84.42 76.02
Training Command
torchrun --nproc_per_node 8 -m training.main \
    --train-data 'mscoco_captions' --retrieval-data 'mscoco_captions' \
    --retrieval-frequency 1 --datasets-dir '/data/Datasets' \
    --epochs 15 --save-frequency 15 --batch-size 224 --workers 4 \
    --lr 7e-5 --warmup 100 --weight_decay 1.0 --max-grad-norm 5 \
    --image-model 'RN50' --image-model-builder 'openclip' --text-model 'RN50' --text-model-builder 'openclip'\
    --pretrained-image-model --pretrained-text-model --lock-image-model \
    --lock-text-partial 'positional_embedding,token_embedding' \
    --lock-image-partial '!attnpool,!layer4' \
    --loss 'InfoNCE' \
    --report-to tensorboard --logs 'logs/MSCOCO-RN50-partial'  --name 'save-lock-image(!attnpool,!layer4)-lock-text(positional_embedding,token_embedding)-bs1792-lr7e-5'

More Tricks for Fine-tuning

Layer-wise Learning Rate Decay (LLDR)

--layer_decay_image 0.9 --layer_decay_text 1 \
for layer_decay_text in 1.0 0.95 0.9 0.85 0.8 0.75 0.7 0.65 0.6;
do
torchrun --nproc_per_node 8 -m training.main \
    --train-data 'mscoco_captions' --retrieval-data 'mscoco_captions' \
    --retrieval-frequency 1 --datasets-dir '/data/Datasets' \
    --epochs 15 --save-frequency 0 --batch-size 224 --workers 2 \
    --lr 7e-5 --warmup 100 --weight_decay 1.0 --max-grad-norm 5 \
    --image-model 'RN50' --image-model-builder 'openclip' --text-model 'RN50' --text-model-builder 'openclip'\
    --pretrained-image-model --pretrained-text-model --lock-image-model \
    --lock-text-partial 'positional_embedding,token_embedding' \
    --lock-image-partial '!attnpool,!layer4' \
    --loss 'InfoNCE' \
    --report-to tensorboard --logs 'logs/MSCOCO-RN50-LLDR'  --name 'layer_decay_text='$layer_decay_text \
    --layer_decay_text $layer_decay_text; 
done

Exponential Moving Average (EMA)

--model_ema --model_ema_decay 0.998 \
for model_ema_decay in 0.99999 0.9999 0.9995 0.999 0.995 0.99 0.95 0.9 0.8;
do
torchrun --nproc_per_node 8 -m training.main \
    --train-data 'mscoco_captions' --retrieval-data 'mscoco_captions' \
    --retrieval-frequency 1 --datasets-dir '/data/Datasets' \
    --epochs 15 --save-frequency 0 --batch-size 224 --workers 2 \
    --lr 7e-5 --warmup 100 --weight_decay 1.0 --max-grad-norm 5 \
    --image-model 'RN50' --image-model-builder 'openclip' --text-model 'RN50' --text-model-builder 'openclip'\
    --pretrained-image-model --pretrained-text-model --lock-image-model \
    --lock-text-partial 'positional_embedding,token_embedding' \
    --lock-image-partial '!attnpool,!layer4' \
    --loss 'InfoNCE' \
    --model_ema --model_ema_decay $model_ema_decay \
    --report-to tensorboard --logs 'logs/MSCOCO-RN50-EMA'  --name 'model_ema_decay='$model_ema_decay;
done

Wise-FT. Evaluate the model with weight space ensemble

Wise-FT

--eval-with-wise-ft 0.5 \
for alpha in 0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 ;
do
python itra/training/main.py \
    --zeroshot-frequency 1 --retrieval-frequency 1 --retrieval-data 'mscoco_captions' --datasets-dir '/data/Datasets' \
    --image-model 'RN50' --image-model-builder 'openclip'  \
    --text-model 'RN50' --text-model-builder 'openclip'  \
    --pretrained-image-model --pretrained-text-model \
    --resume 'logs/MSCOCO-RN50-partial/save-lock-image(!attnpool,!layer4)-lock-text(positional_embedding,token_embedding)-bs1792-lr7e-5/checkpoints/epoch_15.pt' \
    --eval-with-wise-ft $alpha \
    --logs 'logs/MSCOCO-RN50-WiseFT'  --name 'zs+retrieval-WiseFT='$alpha;
done

rsicd retrieval

# 1x2080ti machine RSICD 对点
torchrun --nproc_per_node 8 -m training.main \
    --train-data '/data/Datasets/RSICD/csv/rsicd_train.csv' --images-dir '/data/Datasets/RSICD/RSICD_images/RSICD_images' \
    --csv-separator '\t' --csv-img-key 'filename' --csv-caption-key 'title' \
    --retrieval-data '/data/Datasets/RSICD/csv/rsicd_test.csv' --retrieval-images-dir '/data/Datasets/RSICD/RSICD_images/RSICD_images' \
    --retrieval-csv-separator '\t' --retrieval-csv-img-key 'filename' --retrieval-csv-caption-key 'title' \
    --retrieval-frequency 1  --datasets-dir '/data/Datasets' \
    --epochs 30 --save-frequency 0 --batch-size 16 --workers 2 \
    --lr 1e-6 --warmup 100 --weight_decay 0.5 --max-grad-norm 5 \
    --image-model 'ViT-L-14-336' --image-model-builder 'openclip' \
    --text-model 'ViT-L-14-336' --text-model-builder 'openclip' \
    --pretrained-image-model --pretrained-text-model \
    --lock-image-model --lock-text-model \
    --lock-image-partial '!ln_post,!resblocks.23,!resblocks.22,!resblocks.21,!resblocks.20,!resblocks.19,!resblocks.18' \
    --lock-text-partial '!text_projection,!ln_final,!resblocks.11,!resblocks.10,!resblocks.9' \
    --loss 'InfoNCE' --layer_decay_image 0.9 --layer_decay_text 0.9 \
    --report-to tensorboard --logs 'logs/RSICD-ViT-L-14'  --name '30ep-b128-lr1e-5-unlock-image-text-last0.75-lldr0.9'

python itra/training/main.py –config-yaml ‘logs/params.yml’ –name ‘custom-name’

python itra/training/main.py –episode-size 10000 –train-data ‘mscoco_captions’ –retrieval-data ‘mscoco_captions’ –retrieval-frequency 1 –datasets-dir ‘/data/Datasets’ –epochs 15 –save-frequency 0 –batch-size 100 –workers 2 –lr 1e-4 –warmup 100 –weight_decay 1.0 –max-grad-norm 5 –image-model ‘RN50’ –image-model-builder ‘openclip’ –text-model ‘RN50’ –text-model-builder ‘openclip’–pretrained-image-model –pretrained-text-model –lock-image-model –lock-text-model –loss ‘InfoNCE’ –prompt –n-prompt 4 –report-to tensorboard –logs ‘logs/test’ –name ‘coco-finetune-nprompt-4’