diff --git a/repeng/extract.py b/repeng/extract.py index 55de4b0..9de957b 100644 --- a/repeng/extract.py +++ b/repeng/extract.py @@ -276,15 +276,23 @@ def batched_get_hiddens( hidden_states = {layer: [] for layer in hidden_layers} with torch.no_grad(): for batch in tqdm.tqdm(batched_inputs): - out = model( - **tokenizer(batch, padding=True, return_tensors="pt").to(model.device), - output_hidden_states=True, - ) - for layer in hidden_layers: - # if not indexing from end, account for embedding hiddens - hidden_idx = layer + 1 if layer >= 0 else layer - for batch in out.hidden_states[hidden_idx]: - hidden_states[layer].append(batch[-1, :].squeeze().cpu().numpy()) + # get the last token, handling right padding if present + encoded_batch = tokenizer(batch, padding=True, return_tensors="pt") + encoded_batch = encoded_batch.to(model.device) + out = model(**encoded_batch, output_hidden_states=True) + attention_mask = encoded_batch["attention_mask"] + for i in range(len(batch)): + last_non_padding_index = ( + attention_mask[i].nonzero(as_tuple=True)[0][-1].item() + ) + for layer in hidden_layers: + hidden_idx = layer + 1 if layer >= 0 else layer + hidden_state = ( + out.hidden_states[hidden_idx][i][last_non_padding_index] + .cpu() + .numpy() + ) + hidden_states[layer].append(hidden_state) del out return {k: np.vstack(v) for k, v in hidden_states.items()} diff --git a/repeng/tests.py b/repeng/tests.py index b0441fc..d32cd77 100644 --- a/repeng/tests.py +++ b/repeng/tests.py @@ -68,13 +68,17 @@ def gen(vector: ControlVector | None, strength_coeff: float | None = None): assert baseline == gen(happy_vector * 0.0) assert baseline == gen(happy_vector - happy_vector) - assert happy == "You are feeling great and happy. I'm" + assert happy == "You are feeling a little more relaxed and enjoying" # these should be identical assert happy == gen(happy_vector, 20.0) assert happy == gen(happy_vector * 20) assert happy == gen(-(happy_vector * -20)) - assert sad == "You are feeling the worst,\n—(" + assert sad == 'You are feeling the fucking damn goddamn worst,"' + # these should be identical + assert sad == gen(happy_vector, -50.0) + assert sad == gen(happy_vector * -50) + assert sad == gen(-(happy_vector * 50)) def test_train_llama_tinystories():