Metadata-Version: 2.1
Name: dataset_interfaces
Version: 0.1.0
Summary: Dataset Interfaces
Home-page: https://github.com/MadryLab/dataset-interfaces
Author: MadryLab
Author-email: jvendrow@mit.edu
License: MIT
Description-Content-Type: text/markdown
License-File: LICENSE

# Dataset Interfaces

This repository contains the code for our recent work:

**Dataset Interfaces: Diagnosing Model Failures Using Controllable Counterfactual Generation** <br>
*Joshua Vendrow\*, Saachi Jain\*, Logan Engstrom, Aleksander Madry* <br>
Paper: [https://arxiv.org/abs/2302.07865](https://arxiv.org/abs/2302.07865) <br>
Blog post: [https://gradientscience.org/dataset-interfaces/](https://gradientscience.org/dataset-interfaces/)

<p>
<img src="dogs.png" width="1000" >
</p>

## Getting started
Install using pip, or clone our repository.
```bash
pip install dataset-interfaces
```

**Example:** For a walkthrough of the codebase, check out our [example notebook](notebooks/Example.ipynb). This notebook shows how to
construct a dataset interface for a subset of ImageNet and generate counterfactual examples.

Before running `run_textual_inversion`, initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

## Constructing a Dataset Interface
Constructing a dataset interface consists of learning a *class token* for each class in a dataset, which can then be included in textual prompts.

To learn a single token, we use the following function:
```python
from dataset_interfaces import run_textual_inversion

embed = run_textual_inversion(
    train_path=train_path,  # path to directory with training set for a single class
    token=token,            # text to use for new token, e.g., "<plate>"
    class_name=class_name,  # natural language class description, e.g., "plate"
)
```
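
To build the `embeds` and `tokens` lists used in the next step, the call above can be repeated once per class. The following is a minimal sketch, not part of the package: `make_token`, `learn_all_tokens`, the token naming scheme, and the one-subdirectory-per-class layout under `train_root` are all assumptions made for illustration. The learning function is passed in as `learn_fn` so the loop can be tried without a GPU; in practice it would be `run_textual_inversion`.

```python
from pathlib import Path


def make_token(class_name):
    """Build a token string such as "<golden-retriever>" (naming scheme is an assumption)."""
    return "<" + class_name.strip().lower().replace(" ", "-") + ">"


def learn_all_tokens(train_root, class_names, learn_fn):
    """Call learn_fn once per class and collect the results in class order.

    Assumes (hypothetically) that train_root contains one subdirectory per
    class, named after the class description.
    """
    embeds, tokens = [], []
    for class_name in class_names:
        token = make_token(class_name)
        embed = learn_fn(
            train_path=str(Path(train_root) / class_name),
            token=token,
            class_name=class_name,
        )
        embeds.append(embed)
        tokens.append(token)
    return embeds, tokens
```

With the package installed, this might be used as `embeds, tokens = learn_all_tokens("./train", class_names, run_textual_inversion)`. Keeping `class_names`, `tokens`, and `embeds` in the same order matters because classes are later indexed by their position.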

Once all the class tokens are learned, we can create a custom tokenizer and text encoder with these tokens:

```python
from dataset_interfaces import inference_utils as infer_utils

infer_utils.create_encoder(
    embeds=embeds,             # list of learned embeddings (from the code block above)
    tokens=tokens,             # list of token strings
    class_names=class_names,   # list of natural language class descriptions
    encoder_root=encoder_root  # path where to store the tokenizer and encoder
)
```

## Generating Counterfactual Examples

We can now generate counterfactual examples by incorporating our learned tokens in textual prompts. The ``generate`` function generates images for a specific class in the dataset (indexed in the order that classes are passed when constructing the encoder). When specifying the text prompt, `<TOKEN>` acts as a placeholder for the class token.
```python
from dataset_interfaces import generate

generate(
    encoder_root=encoder_root,
    c=c,                                          # index of a specific class
    prompts="a photo of a <TOKEN> in the grass",  # can be a single prompt or a list of prompts
    num_samples=10,
    random_seed=0                                 # no seed by default
)
```
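
To make the placeholder convention concrete, the sketch below substitutes a class token for `<TOKEN>` in a single prompt or a list of prompts. It only illustrates the behavior described above; the helper names are hypothetical, and this is not necessarily how `generate` performs the substitution internally.

```python
PLACEHOLDER = "<TOKEN>"


def fill_prompt(prompt, class_token):
    """Replace the placeholder in one prompt with the given class token."""
    if PLACEHOLDER not in prompt:
        raise ValueError(f"prompt has no {PLACEHOLDER} placeholder: {prompt!r}")
    return prompt.replace(PLACEHOLDER, class_token)


def fill_prompts(prompts, class_token):
    """Accept a single prompt or a list of prompts and always return a list."""
    if isinstance(prompts, str):
        prompts = [prompts]
    return [fill_prompt(p, class_token) for p in prompts]


filled = fill_prompts(
    ["a photo of a <TOKEN> in the grass", "a photo of a <TOKEN> in the snow"],
    "<plate>",
)
```

Because the prompt is written once with a placeholder, the same shift description can be reused across every class in the dataset.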

## CLIP Metric

To directly evaluate the quality of the generated images, we use *CLIP similarity* to quantify the presence of the object of interest and the desired distribution shift in each image.

We can measure CLIP similarity between a set of generated images and a given caption as follows:

```python
sim_class = infer_utils.clip_similarity(imgs, "a photo of a dog")
sim_shift = infer_utils.clip_similarity(imgs, "a photo in the grass")
```
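
CLIP similarity is commonly computed as the cosine similarity between an image embedding and a caption embedding. Assuming that definition applies here, the sketch below shows the arithmetic on toy vectors that stand in for real CLIP embeddings. The function names are hypothetical, and `clip_similarity` itself may differ in details such as scaling or batching.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)


def similarities(image_embeds, text_embed):
    """One similarity score per image against a single caption embedding."""
    return [cosine_similarity(img, text_embed) for img in image_embeds]


# Toy 2-d "embeddings": the first image points the same way as the caption,
# the second is orthogonal to it.
toy_images = [[3.0, 4.0], [-4.0, 3.0]]
toy_caption = [6.0, 8.0]
scores = similarities(toy_images, toy_caption)
```

Scoring each image against both a class caption and a shift caption gives two numbers per image, one for whether the object is present and one for whether the shift was rendered.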


## ImageNet* Benchmark
Our benchmark for the ImageNet dataset consists of two components: our 1,000 learned class tokens for ImageNet, and the images generated with these tokens under 23 distribution shifts.

### ImageNet* Tokens

The 1,000 learned tokens are available on [HuggingFace](https://huggingface.co/datasets/madrylab/imagenet-star-tokens) and can be downloaded with:
```bash
wget https://huggingface.co/datasets/madrylab/imagenet-star-tokens/resolve/main/tokens.zip
```
To generate images with these tokens, we first create a text encoder with the tokens, which we use to seamlessly integrate the tokens in text prompts:

```python
token_path = "./tokens"  # path to the tokens from HuggingFace
infer_utils.create_imagenet_star_encoder(token_path, encoder_root="./encoder_root_imagenet")
```
Now, we can generate counterfactual examples of ImageNet from a textual prompt (see the [example notebook](notebooks/Example.ipynb) for a walk-through):
```python
from dataset_interfaces import generate

encoder_root = "./encoder_root_imagenet"
c = 207  # the class for golden retriever
prompt = "a photo of a <TOKEN> wearing a hat"
generate(encoder_root, c, prompt, num_samples=10)
```

### ImageNet* Images
Our benchmark contains images under 23 distribution shifts, with 50k images per shift (50 per class for 1,000 classes). These images are also available on [HuggingFace](https://huggingface.co/datasets/madrylab/imagenet-star). In this repo, we also provide masks for each distribution shift, stored in `masks.npy`, indicating which images we filter out with our CLIP metrics.

We provide a wrapper on top of `torchvision.datasets.ImageFolder` to construct a dataset object that filters the images in the benchmark using this mask. We can make a dataset object for a shift as follows:

```python
from dataset_interfaces import utils

root = "./imagenet-star"     # the path to the dataset downloaded from HuggingFace
mask_path = "./masks.npy"    # the path to the mask file
shift = "in_the_snow"        # the distribution shift of interest

ds = utils.ImageNet_Star_Dataset(
    root, 
    shift=shift,
    mask_path=mask_path
)
```
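
The idea of filtering with a mask can be sketched in plain Python. This is an illustration only: the actual layout of `masks.npy` and whether its entries mark kept or filtered images are not specified here, so the `keep_value` parameter and the `apply_mask` helper are assumptions.

```python
def apply_mask(samples, mask, keep_value=True):
    """Return the samples whose mask entry equals keep_value.

    samples: list of (path, class_index) pairs, as ImageFolder would list them.
    mask:    one boolean per sample, in the same order.
    """
    if len(samples) != len(mask):
        raise ValueError("mask must have one entry per sample")
    return [s for s, m in zip(samples, mask) if bool(m) == keep_value]


# Toy data standing in for one distribution shift.
samples = [("img_0.png", 0), ("img_1.png", 0), ("img_2.png", 1)]
mask = [True, False, True]
kept = apply_mask(samples, mask)
```

The filtered list keeps the original `(path, class_index)` pairs, so class labels stay aligned with the images that pass the filter.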

## Citation
To cite this paper, please use the following BibTeX entry:
```
@inproceedings{vendrow2023dataset,
   title = {Dataset Interfaces: Diagnosing Model Failures Using Controllable Counterfactual Generation},
   author = {Joshua Vendrow and Saachi Jain and Logan Engstrom and Aleksander Madry}, 
   booktitle = {ArXiv preprint arXiv:2302.07865},
   year = {2023}
}
```

## Maintainers
[Josh Vendrow](https://twitter.com/josh_vendrow)<br>
[Saachi Jain](https://twitter.com/saachi_jain_)
