
Commit

fixes for vis
Minor: store each test entry as a (model name, block selector) pair and derive the transformer input shape from the model config instead of hardcoding it.
plutonium-239 committed May 7, 2024
1 parent c0fdf21 commit 4410b5f
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions experiments/util/visualize_graph.py
@@ -11,19 +11,17 @@
 from transformers import AutoConfig
 
 to_test = {
-    'bert_encoder': lambda: models.transformer_model_fns['bert']().bert.encoder.layer[0],
-    'memsave_bert_encoder': lambda: models.transformer_model_fns['memsave_bert']().bert.encoder.layer[0],
-    'bart_encoder': lambda: models.transformer_model_fns['bart']().decoder.layers[0],
-    'memsave_bart_encoder': lambda: models.transformer_model_fns['memsave_bart']().decoder.layers[0],
-    'gpt2_layer': lambda: models.transformer_model_fns['gpt2']().transformer.h[0],
-    'memsave_gpt2_layer': lambda: models.transformer_model_fns['memsave_gpt2']().transformer.h[0],
-    't5_decoder': lambda: models.transformer_model_fns['t5']().decoder.block[1],
-    'memsave_t5_decoder': lambda: models.transformer_model_fns['memsave_t5']().decoder.block[1],
+    'bert_encoder': ['bert', lambda model: model.bert.encoder.layer[0]],
+    'memsave_bert_encoder': ['memsave_bert', lambda model: model.bert.encoder.layer[0]],
+    'bart_encoder': ['bart', lambda model: model.decoder.layers[0]],
+    'memsave_bart_encoder': ['memsave_bart', lambda model: model.decoder.layers[0]],
+    'gpt2_layer': ['gpt2', lambda model: model.transformer.h[0]],
+    'memsave_gpt2_layer': ['memsave_gpt2', lambda model: model.transformer.h[0]],
+    't5_decoder': ['t5', lambda model: model.decoder.block[1]],
+    'memsave_t5_decoder': ['memsave_t5', lambda model: model.decoder.block[1]],
 }
 
-def run_single(model_fn, name, x):
-    model = model_fn()
-
+def run_single(model, name, x):
     y = model(x)
     dot = make_dot(
         y.mean(),
@@ -47,8 +45,13 @@ def run_single(model_fn, name, x):
 # models.conv_input_shape = (3, 64, 64)
 models.transformer_input_shape = (5000, 1024)
 
-for name,model_fn in to_test.items():
+for name in to_test:
+    model_name, block_fn = to_test[name]
+    config = models.get_transformers_config(model_name)
+
+    models.transformer_input_shape = (config.vocab_size, config.hidden_size)
     x = torch.rand(7, *models.transformer_input_shape)
 
-    run_single(model_fn, name, x)
+    model = models.transformer_model_fns.get(model_name)
+    run_single(block_fn(model()), name, x)
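
For reference, the (model name, block selector) pattern introduced above can be reproduced outside the repo. Below is a minimal stand-alone sketch, assuming torchviz and transformers are installed; it swaps the repo's models.transformer_model_fns / models.get_transformers_config helpers for direct transformers calls. The checkpoint name 'bert-base-uncased', the sequence length, and the output path are illustrative assumptions, not values taken from this commit.

    # Stand-alone sketch of the (model name, block selector) pattern above.
    # The repo's `models` helpers are replaced with plain transformers calls.
    import torch
    from torchviz import make_dot
    from transformers import AutoConfig, BertModel

    name = 'bert_encoder'
    checkpoint = 'bert-base-uncased'                 # assumed checkpoint
    block_fn = lambda model: model.encoder.layer[0]  # plain BertModel has no .bert wrapper

    config = AutoConfig.from_pretrained(checkpoint)
    model = BertModel(config)                        # random init; weights don't matter for the graph
    block = block_fn(model)

    # An encoder layer consumes hidden states of shape (batch, seq_len, hidden_size).
    x = torch.rand(7, 32, config.hidden_size)

    y = block(x)[0]                                  # BertLayer returns a tuple
    dot = make_dot(y.mean(), params=dict(block.named_parameters()))
    dot.render(name, format='svg')                   # writes bert_encoder.svg

Reducing the output with .mean() before calling make_dot mirrors run_single above: it roots the rendered autograd graph at a single scalar node.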

