Skip to content

Commit

Permalink
add synthetic neurons
Browse files Browse the repository at this point in the history
  • Loading branch information
Tamar Rott Shaham committed Aug 14, 2024
1 parent 8b489eb commit 43cbf65
Show file tree
Hide file tree
Showing 13 changed files with 5,347 additions and 31 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "synthetic-neurons-dataset/Grounded-Segment-Anything"]
path = synthetic-neurons-dataset/Grounded-Segment-Anything
url = https://github.com/IDEA-Research/Grounded-Segment-Anything.git
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ MAIA is a system that uses neural models to automate neural model understanding
**News**
\
[July 3]: We release MAIA implementation code for neuron labeling
\
[August 14]: Synthetic neurons are now available (both in `demo.ipynb` and in `main.py`)

**This repo is under active development. Sign up for updates by email using [this google form](https://forms.gle/Zs92DHbs3Y3QGjXG6).**

Expand Down Expand Up @@ -45,6 +47,8 @@ jupyter notebook
```
This command will start the Jupyter Notebook server and open the Jupyter Notebook interface in your default web browser. The interface will show all the notebooks, files, and subdirectories in this repo (assuming is was initiated from the maia path). Open ```demo.ipynb``` and proceed according to the instructions.

NEW: `demo.ipynb` now supports synthetic neurons. Follow instalation instructions at `./synthetic-neurons-dataset/README.md`. After installation is done, you can define MAIA to run on synthetic neurons according to the instructions in `demo.ipynb`.

### Batch experimentation ###
To run a batch of experiments, use ```main.py```:

Expand Down Expand Up @@ -73,3 +77,20 @@ Results are automatically saved to an html file under ```./results/``` and can b
python -m http.server 80
```
Once the server is up, open the html in [http://localhost:80](http://localhost:80/results/)

#### Run MAIA on sythetic neurons ####

You can now run maia on synthetic neurons with ground-truth labels (see sec. 4.2 in the paper for more details).

Follow instalation instructions at `./synthetic-neurons-dataset/README.md`. Then you should be able to run `main.py` on synthetic neurons by calling e.g.:
```bash
python main.py --model synthetic_neurons --unit_mode manual --units mono=1,8:or=9:and=0,2,5
```
(neuron indices are specified according to the neuron type: "mono", "or" and "and").

You can also use the .json file to run all synthetic neurons (or specify your own file):
```bash
python main.py --model synthetic_neurons --unit_mode from_file --unit_file_path ./neuron_indices/
```
### Acknowledgment ###
[Christy](https://www.linkedin.com/in/christykl/) helped with cleaning up the synthetic neurons code for release.
1,803 changes: 1,782 additions & 21 deletions demo.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions download_exemplars.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ wget -P ./exemplars data.csail.mit.edu/maia-exemplars/dino.zip
unzip ./exemplars/dino.zip -d ./exemplars/
wget -P ./exemplars data.csail.mit.edu/maia-exemplars/resnet152.zip
unzip ./exemplars/resnet152.zip -d ./exemplars/
wget -P ./exemplars data.csail.mit.edu/maia-exemplars/synthetic_neurons.zip
unzip ./exemplars/synthetic_neurons.zip -d ./exemplars/
6 changes: 3 additions & 3 deletions maia_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from call_agent import ask_agent
import time
import math
# import synthetic_neurons
import synthetic_neurons
import clip
import torch.nn.functional as F

Expand Down Expand Up @@ -864,7 +864,7 @@ def describe_images(self, image_list: List[str], image_title:List[str]) -> str:
"""
description_list = ''
instructions = "Do not describe the full image. Please describe ONLY the unmasked regions in this image (e.g. the regions that are not darkened). Be as concise as possible. Return your description in the following format: [highlighted regions]: <your concise description>"
time.sleep(60)
# time.sleep(60)
for ind,image in enumerate(image_list):
history = [{'role':'system', 'content':'you are an helpful assistant'},{'role': 'user', 'content': [{"type":"text", "text": instructions}, {"type": "image_url", "image_url": "data:image/jpeg;base64," + image}]}]
description = ask_agent('gpt-4-vision-preview',history)
Expand Down Expand Up @@ -1092,7 +1092,7 @@ def net_dissect(self,layer,im_size=224):

class SyntheticExemplars():

def __init__(self, path2exemplars, n_exemplars, path2save, mode, im_size=224):
def __init__(self, path2exemplars, path2save, mode, n_exemplars=15, im_size=224):
self.path2exemplars = path2exemplars
self.n_exemplars = n_exemplars
self.path2save = path2save
Expand Down
25 changes: 18 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
layers = {
"resnet152": ['conv1','layer1','layer2','layer3','layer4'],
"clip-RN50" : ['layer1','layer2','layer3','layer4'],
"dino_vits8": ['blocks.1.mlp.fc1','blocks.3.mlp.fc1','blocks.5.mlp.fc1','blocks.7.mlp.fc1','blocks.9.mlp.fc1','blocks.11.mlp.fc1']
"dino_vits8": ['blocks.1.mlp.fc1','blocks.3.mlp.fc1','blocks.5.mlp.fc1','blocks.7.mlp.fc1','blocks.9.mlp.fc1','blocks.11.mlp.fc1'],
"synthetic_neurons": ['mono','or','and']
}

def call_argparse():
parser = argparse.ArgumentParser(description='Process Arguments')
parser.add_argument('--maia', type=str, default='gpt-4-vision-preview', choices=['gpt-4-vision-preview','gpt-4-turbo'], help='maia agent name')
parser.add_argument('--task', type=str, default='neuron_description', choices=['neuron_description'], help='task to solve, default is neuron description') #TODO: add other tasks
parser.add_argument('--model', type=str, default='resnet152', choices=['resnet152','clip-RN50','dino_vits8'], help='model to interp') #TODO: add synthetic neurons
parser.add_argument('--model', type=str, default='resnet152', choices=['resnet152','clip-RN50','dino_vits8','synthetic_neurons'], help='model to interp') #TODO: add synthetic neurons
parser.add_argument('--units', type=str2dict, default='layer4=122', help='units to interp')
parser.add_argument('--unit_mode', type=str, default='manual', choices=['from_file','random','manual'], help='units to interp')
parser.add_argument('--unit_file_path', type=str, default='./neuron_indices/', help='units to interp')
Expand All @@ -49,7 +50,6 @@ def str2dict(arg_value):
key, value = item.split('=')
values = value.split(',')
my_dict[key] = [int(v) for v in values]
embed()
return my_dict


Expand Down Expand Up @@ -152,16 +152,27 @@ def interpretation_experiment(maia,system,tools,debug=False):
def main(args):
maia_api, user_query = return_Prompt(args.path2prompts, setting=args.task) # load system prompt (maia api) and user prompt (the user query)
unit_inx = units2explore(args.unit_mode) # returns a dictionary of {'layer':[units]} to explore
for layer in unit_inx.keys():
for layer in unit_inx.keys(): # for the synthetic neurons, the layer is the neuron type ("mono", "or", "and")
units = unit_inx[layer]
net_dissect = DatasetExemplars(args.path2exemplars, args.path2save, args.model, layer, units) # precomputes dataset examplars for tools.dataset_exemplars
if args.model == 'synthetic_neurons':
net_dissect = SyntheticExemplars(os.path.join(args.path2exemplars, args.model), args.path2save, layer) # precomputes synthetic dataset examplars for tools.dataset_exemplars.
with open(os.path.join('./synthetic-neurons-dataset/labels/',f'{layer}.json'), 'r') as file: # load the synthetic neuron labels
synthetic_neuron_data = json.load(file)
else:
net_dissect = DatasetExemplars(args.path2exemplars, args.path2save, args.model, layer, units) # precomputes dataset examplars for tools.dataset_exemplars
for unit in units:
print(layer,unit)
path2save = os.path.join(args.path2save,args.maia,args.model,str(layer),str(unit))
if os.path.exists(path2save+'/description.txt'): continue
os.makedirs(path2save, exist_ok=True)
system = System(unit, layer, args.model, args.device, net_dissect.thresholds) # initialize the system class
tools = Tools(path2save, args.device, net_dissect, text2image_model_name=args.text2image) # initialize the tools class
if args.model == 'synthetic_neurons':
gt_label = synthetic_neuron_data[unit]["label"].rsplit('_')
print("groundtruth label: ",gt_label)
system = Synthetic_System(unit, gt_label, layer, args.device)
tools = Tools(path2save, args.device, net_dissect, text2image_model_name=args.text2image, images_per_prompt=1) # initialize the tools class
else:
system = System(unit, layer, args.model, args.device, net_dissect.thresholds) # initialize the system class
tools = Tools(path2save, args.device, net_dissect, text2image_model_name=args.text2image) # initialize the tools class
tools.update_experiment_log(role='system', type="text", type_content=maia_api) # update the experiment log with the system prompt
tools.update_experiment_log(role='user', type="text", type_content=user_query) # update the experiment log with the user prompt
interp_count = 0
Expand Down
1 change: 1 addition & 0 deletions neuron_indices/synthetic_neurons.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"mono": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46], "or": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], "and": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]}
1 change: 1 addition & 0 deletions synthetic-neurons-dataset/Grounded-Segment-Anything
60 changes: 60 additions & 0 deletions synthetic-neurons-dataset/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Setup Instructions for Synthetic Neurons

## To set up the synthetic neurons:

1. **init Grounded-SAM submodule**
```
git submodule init
git submodule update
```

2. **Follow the setup instructions on Grounded-SAM setup:**
- Export global variables (choose whether to run on CPU or GPU; note that running on CPU is feasible but slower, approximately 3 seconds per image):
```bash
export AM_I_DOCKER=False
export BUILD_WITH_CUDA=True
export CUDA_HOME=/path/to/cuda-11.3/
```
- Install Segment Anything:
```bash
python -m pip install -e segment_anything
```
- Install Grounding Dino:
```bash
python -m pip install -e GroundingDINO
```
- Install diffusers:
```bash
pip install --upgrade diffusers[torch]
```
- Install osx:
```bash
git submodule update --init --recursive
cd grounded-sam-osx && bash install.sh
```

3. **Download grounded DINO and grounded SAM .pth files**
- Download groudned DINO:
```bash
cd .. #back to ./Grounded_Segment-Anything
#download the pretrained groundingdino-swin-tiny model
wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
```
- Download grounded SAM:
```bash
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
```
- Try running grounded SAM demo:
```bash
export CUDA_VISIBLE_DEVICES=0
python grounded_sam_demo.py \
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
--grounded_checkpoint groundingdino_swint_ogc.pth \
--sam_checkpoint sam_vit_h_4b8939.pth \
--input_image assets/demo1.jpg \
--output_dir "outputs" \
--box_threshold 0.3 \
--text_threshold 0.25 \
--text_prompt "bear" \
--device "cpu"
```
Loading

0 comments on commit 43cbf65

Please sign in to comment.