
Commit

reformat / add comments
vgel committed Jul 2, 2024
1 parent d28c70b · commit 73177b9
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions repeng/extract.py
@@ -276,14 +276,22 @@ def batched_get_hiddens(
     hidden_states = {layer: [] for layer in hidden_layers}
     with torch.no_grad():
         for batch in tqdm.tqdm(batched_inputs):
-            encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
+            # get the last token, handling right padding if present
+            encoded_batch = tokenizer(batch, padding=True, return_tensors="pt")
+            encoded_batch = encoded_batch.to(model.device)
             out = model(**encoded_batch, output_hidden_states=True)
-            attention_mask = encoded_batch['attention_mask']
+            attention_mask = encoded_batch["attention_mask"]
             for i in range(len(batch)):
-                last_non_padding_index = attention_mask[i].nonzero(as_tuple=True)[0][-1].item()
+                last_non_padding_index = (
+                    attention_mask[i].nonzero(as_tuple=True)[0][-1].item()
+                )
                 for layer in hidden_layers:
                     hidden_idx = layer + 1 if layer >= 0 else layer
-                    hidden_state = out.hidden_states[hidden_idx][i][last_non_padding_index].cpu().numpy()
+                    hidden_state = (
+                        out.hidden_states[hidden_idx][i][last_non_padding_index]
+                        .cpu()
+                        .numpy()
+                    )
                     hidden_states[layer].append(hidden_state)
             del out
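For context, here is a minimal standalone sketch of the padding-aware last-token selection the new lines perform, with toy tensors standing in for a real model and tokenizer (the mask values and shapes below are invented for illustration, not part of the commit):

    import torch

    # Two right-padded sequences: 1 = real token, 0 = padding.
    attention_mask = torch.tensor([
        [1, 1, 1, 0, 0],  # last real token at index 2
        [1, 1, 1, 1, 1],  # no padding; last real token at index 4
    ])

    # Stand-in for out.hidden_states[hidden_idx], shaped (batch, seq_len, hidden_dim).
    hidden = torch.randn(2, 5, 8)

    for i in range(attention_mask.shape[0]):
        # Index of the final 1 in the mask = the last non-padding token.
        last_non_padding_index = attention_mask[i].nonzero(as_tuple=True)[0][-1].item()
        last_hidden = hidden[i][last_non_padding_index].cpu().numpy()
        print(i, last_non_padding_index, last_hidden.shape)

The unchanged hidden_idx = layer + 1 if layer >= 0 else layer line reflects Hugging Face's output_hidden_states convention: hidden_states[0] is the embedding output, so non-negative layer indices need a +1 offset, while negative indices already count from the end.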

