CVPR 2023紙的代碼“瓶中的語言:語言模型指導概念瓶頸用於可解釋的圖像分類”
我們使用Python 3.9.13進行實驗。您可以使用以下方式安裝所需的軟件包:
conda create --name labo python=3.9.13
conda activate labo
pip install -r requirements.txt
您需要修改杏子的源代碼以運行子模塊優化。請參閱此處的詳細信息。
cfg/保存所有實驗的配置文件,包括線性探針( cfg/linear_probe )和LABO( cfg/asso_opt )。您可以修改配置文件以更改系統參數。
datasets/存儲數據集特定的數據,包括images , splits和concepts 。請檢查datasets/DATASET.md以獲取詳細信息。
注意:此存儲庫中未提供每個數據集的圖像;您需要下載它們並存儲在相應的文件夾中: datasets/{dataset name}/images/ 。檢查datasets/DATASET.md以獲取有關下載所有數據集的說明。
exp/是實驗的工作目錄。配置文件和模型檢查點將保存在此文件夾中。
models/保存模型:
models/linear_prob/linear_prob.pymodels/asso_opt/asso_opt.pymodels/select_concept/select_algo.py output/ :性能將保存到存儲在output/中的.txt文件中。
其他文件:
data.py和data_lp.py分別是Labo和線性探測器的數據加載器。main.py是運行所有實驗的接口,並且utils.py包含預處理和特徵提取功能。linear probe.sh是運行線性探針的bash文件。 labo_train.sh和labo_test.sh是訓練和測試labo的bash文件。 要獲得線性探測性能,只需運行:
sh linear_probe.sh {DATASET} {SHOTS} {CLIP SIZE}
例如,對於使用VIT-L/14圖像編碼器的Flower DataSet 1-shot,該命令是:
sh linear_probe.sh flower 1 ViT-L/14
代碼將自動編碼圖像,並使用DEV集在L2正則化上運行超參數搜索。最佳驗證和測試性能將保存在output/linear_probe/{DATASET}.txt中。
要訓練Labo,請運行以下命令:
sh labo_train.sh {SHOTS} {DATASET}
培訓日誌將上傳到wandb 。您可能需要在本地設置wandb帳戶。達到最大時期後,具有最高驗證精度的檢查點將保存到exp/asso_opt/{DATASET}/{DATASET}_{SHOT}shot_fac/ 。
要獲得測試性能,請使用保存在exp/asso_opt/{DATASET}/{DATASET}_{SHOT}shot_fac/並運行:
sh labo_test.sh {CONFIG_PATH} {CHECKPOINT_PATH}
測試精度將打印到output/asso_opt/{DATASET}.txt 。
如果您覺得有用,請引用我們的論文!
@inproceedings{yang2023language,
title={Language in a bottle: Language model guided concept bottlenecks for interpretable image classification},
author={Yang, Yue and Panagopoulou, Artemis and Zhou, Shenghao and Jin, Daniel and Callison-Burch, Chris and Yatskar, Mark},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={19187--19197},
year={2023}
}