Kode untuk kertas CVPR 2023 "Bahasa dalam botol: Model bahasa memandu hambatan konsep untuk klasifikasi gambar yang dapat ditafsirkan"
Kami menjalankan percobaan kami menggunakan Python 3.9.13. Anda dapat menginstal paket yang diperlukan menggunakan:
conda create --name labo python=3.9.13
conda activate labo
pip install -r requirements.txt
Anda perlu memodifikasi kode sumber aprikot untuk menjalankan optimasi submodular. Lihat detailnya di sini.
cfg/ menyimpan file konfigurasi untuk semua percobaan, termasuk probe linier ( cfg/linear_probe ) dan labo ( cfg/asso_opt ). Anda dapat memodifikasi file konfigurasi untuk mengubah argumen sistem.
datasets/ Menyimpan data khusus dataset, termasuk images , splits , dan concepts . Silakan periksa datasets/DATASET.md untuk detailnya.
Catatan : Gambar setiap dataset tidak disediakan dalam repo ini; Anda perlu mengunduhnya dan menyimpan di folder yang sesuai: datasets/{dataset name}/images/ . Periksa datasets/DATASET.md untuk instruksi tentang mengunduh semua set data.
exp/ adalah direktori kerja dari percobaan. File konfigurasi dan pos pemeriksaan model akan disimpan di folder ini.
models/ Menyimpan Model:
models/linear_prob/linear_prob.pymodels/asso_opt/asso_opt.pymodels/select_concept/select_algo.py output/ : Kinerja akan disimpan ke dalam file .txt yang disimpan dalam output/ .
File Lainnya:
data.py dan data_lp.py adalah masing -masing dataloaders untuk probe labo dan linier.main.py adalah antarmuka untuk menjalankan semua percobaan, dan utils.py berisi fungsi preprocess dan fitur ekstraksi.linear probe.sh adalah file bash untuk menjalankan probe linier. labo_train.sh dan labo_test.sh adalah file bash untuk melatih dan menguji labo. Untuk mendapatkan kinerja probe linier, jalankan saja:
sh linear_probe.sh {DATASET} {SHOTS} {CLIP SIZE}
Misalnya, untuk dataset bunga 1-shot dengan encoder gambar vit-l/14, perintahnya adalah:
sh linear_probe.sh flower 1 ViT-L/14
Kode akan secara otomatis menyandikan gambar dan menjalankan pencarian hiperparameter pada regularisasi L2 menggunakan set dev. Kinerja validasi dan pengujian terbaik akan disimpan di output/linear_probe/{DATASET}.txt .
Untuk melatih labo, jalankan perintah berikut:
sh labo_train.sh {SHOTS} {DATASET}
Log pelatihan akan diunggah ke wandb . Anda mungkin perlu mengatur akun wandb Anda secara lokal. Setelah mencapai zaman maksimum, pos pemeriksaan dengan akurasi validasi tertinggi dan file konfigurasi yang sesuai akan disimpan ke exp/asso_opt/{DATASET}/{DATASET}_{SHOT}shot_fac/ .
Untuk mendapatkan kinerja pengujian, gunakan pos pemeriksaan model dan konfigurasi yang sesuai yang disimpan di exp/asso_opt/{DATASET}/{DATASET}_{SHOT}shot_fac/ dan run:
sh labo_test.sh {CONFIG_PATH} {CHECKPOINT_PATH}
Akurasi pengujian akan dicetak ke output/asso_opt/{DATASET}.txt .
Harap kutip kertas kami jika Anda merasa berguna!
@inproceedings{yang2023language,
title={Language in a bottle: Language model guided concept bottlenecks for interpretable image classification},
author={Yang, Yue and Panagopoulou, Artemis and Zhou, Shenghao and Jin, Daniel and Callison-Burch, Chris and Yatskar, Mark},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={19187--19197},
year={2023}
}