diff --git a/repeng/extract.py b/repeng/extract.py
index 10e437d..9de957b 100644
--- a/repeng/extract.py
+++ b/repeng/extract.py
@@ -276,14 +276,22 @@ def batched_get_hiddens(
     hidden_states = {layer: [] for layer in hidden_layers}
     with torch.no_grad():
         for batch in tqdm.tqdm(batched_inputs):
-            encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
+            # get the last token, handling right padding if present
+            encoded_batch = tokenizer(batch, padding=True, return_tensors="pt")
+            encoded_batch = encoded_batch.to(model.device)
             out = model(**encoded_batch, output_hidden_states=True)
-            attention_mask = encoded_batch['attention_mask']
+            attention_mask = encoded_batch["attention_mask"]
             for i in range(len(batch)):
-                last_non_padding_index = attention_mask[i].nonzero(as_tuple=True)[0][-1].item()
+                last_non_padding_index = (
+                    attention_mask[i].nonzero(as_tuple=True)[0][-1].item()
+                )
                 for layer in hidden_layers:
                     hidden_idx = layer + 1 if layer >= 0 else layer
-                    hidden_state = out.hidden_states[hidden_idx][i][last_non_padding_index].cpu().numpy()
+                    hidden_state = (
+                        out.hidden_states[hidden_idx][i][last_non_padding_index]
+                        .cpu()
+                        .numpy()
+                    )
                     hidden_states[layer].append(hidden_state)
             del out
 
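For context (not part of the patch above), a minimal sketch of how the attention mask is used to pick out the last real token when the tokenizer right-pads a batch; the mask values below are made up purely for illustration:

import torch

# attention mask for a batch of two right-padded sequences:
# 1 marks a real token, 0 marks padding
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],  # sequence of length 3, padded to 5
    [1, 1, 1, 1, 1],  # sequence of length 5, no padding
])

for i in range(attention_mask.shape[0]):
    # nonzero() lists the positions of real tokens; the last one is the
    # index the patch uses to read that row's final hidden state
    last_non_padding_index = attention_mask[i].nonzero(as_tuple=True)[0][-1].item()
    print(i, last_non_padding_index)  # prints: 0 2, then 1 4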