Kyle1668 committed Apr 22, 2023
1 parent d61375c commit 69194ee
Showing 1 changed file with 107 additions and 96 deletions.
203 changes: 107 additions & 96 deletions elk/extraction/extraction.py
@@ -5,12 +5,14 @@
from dataclasses import InitVar, dataclass
from itertools import islice
from typing import Any, Iterable, Literal
from warnings import filterwarnings

import torch
from datasets import (
Array2D,
Array3D,
ClassLabel,
DatasetDict,
DownloadMode,
Features,
Sequence,
SplitDict,
@@ -20,23 +22,23 @@
)
from simple_parsing import Serializable, field
from torch import Tensor
from transformers import AutoConfig, AutoTokenizer, GPT2TokenizerFast
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput

from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
convert_span,
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
select_train_val_splits,
select_usable_devices,
)
from .balanced_sampler import BalancedSampler
from .generator import _GeneratorBuilder
from .prompt_loading import PromptConfig, load_prompts
from ..rwkv_lm.rwkv_hf import RWKVConfig


@dataclass
Expand All @@ -58,6 +60,7 @@ class Extract(Serializable):
layers: tuple[int, ...] = ()
layer_stride: InitVar[int] = 1
token_loc: Literal["first", "last", "mean"] = "last"
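# Whether to extract hidden states from the encoder only (for encoder-decoder models).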
use_encoder_states: bool = False

def __post_init__(self, layer_stride: int):
if self.layers and layer_stride > 1:
@@ -85,7 +88,7 @@ def explode(self) -> list["Extract"]:
return copies


@torch.no_grad()
@torch.inference_mode()
def extract_hiddens(
cfg: "Extract",
*,
@@ -99,135 +102,135 @@ def extract_hiddens(

# Silence datasets logging messages from all but the first process
if rank != 0:
filterwarnings("ignore")
logging.disable(logging.CRITICAL)

ds_names = cfg.prompts.datasets
p_cfg = cfg.prompts
ds_names = p_cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

prompt_ds = load_prompts(
ds_names[0],
split_type=split_type,
stream=cfg.prompts.stream,
rank=rank,
world_size=world_size,
) # this dataset is already sharded, but hasn't been truncated to max_examples

model = instantiate_model(
cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
).to(device)
tokenizer = None
tokenizer = instantiate_tokenizer(
cfg.model, truncation_side="left", verbose=rank == 0
)

if cfg.model.startswith("RWKV"):
tokenizer = GPT2TokenizerFast(tokenizer_file='/home/kyle/repos/elk/elk/rwkv_lm/20B_tokenizer.json')
else:
tokenizer = AutoTokenizer.from_pretrained(
cfg.model, truncation_side="left", verbose=False
)
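# For encoder-decoder models, optionally swap the full model for its encoder stack
# so that only encoder hidden states are extracted.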
is_enc_dec = model.config.is_encoder_decoder
if is_enc_dec and cfg.use_encoder_states:
assert hasattr(model, "get_encoder") and callable(model.get_encoder)
model = assert_type(PreTrainedModel, model.get_encoder())
is_enc_dec = False

has_lm_preds = is_autoregressive(model.config)
has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states)
if has_lm_preds and rank == 0:
print("Model has language model head, will store predictions.")

prompt_ds = load_prompts(
ds_names[0],
label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None,
num_classes=p_cfg.num_classes,
split_type=split_type,
stream=p_cfg.stream,
rank=rank,
world_size=world_size,
)

# Iterating over questions
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))

global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1]
# break `max_examples` among the processes roughly equally
max_examples = global_max_examples // world_size
# the last process gets the remainder (which is usually small)
if rank == world_size - 1:
max_examples += global_max_examples % world_size
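# e.g. 1000 examples across 3 processes -> 333, 333, and 334 for the last rank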

for example in islice(BalancedSampler(prompt_ds), max_examples):
for example in islice(prompt_ds, max_examples):
num_variants = len(example["prompts"])
num_choices = len(example["prompts"][0])
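# Pre-allocate per-example buffers below: int16 hidden states for each layer and
# float32 answer log-likelihoods for each (variant, class) pair.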

hidden_dict = {
f"hidden_{layer_idx}": torch.empty(
num_variants,
2, # contrast pair
num_choices,
model.config.hidden_size,
device=device,
dtype=torch.int16,
)
for layer_idx in layer_indices
}
lm_preds = torch.empty(
lm_logits = torch.empty(
num_variants,
2, # contrast pair
num_choices,
device=device,
dtype=torch.float32,
)
text_inputs = []
text_questions = []

# Iterate over variants
for i, record in enumerate(example["prompts"]):
variant_inputs = []
variant_questions = []

# Iterate over answers
for j, choice in enumerate(record):
text = choice["text"]

# TODO: Do something smarter than "rindex" here. Really we want to
# get the span of the answer directly from Jinja, but that doesn't
# seem possible. This approach may fail for complex templates.
answer_start = text.rindex(choice["answer"])
text = choice["question"]

# Only feed question, not the answer, to the encoder for enc-dec models
if model.config.is_encoder_decoder:
# TODO: Maybe make this more generic for complex templates?
text = text[:answer_start].rstrip()
target = choice["answer"]
else:
target = None

# Record the EXACT string we fed to the model
variant_inputs.append(text)
# inputs = None
# if cfg.model.startswith("RWKV"):
# inputs = tokenizer(
# text,
# return_offsets_mapping=True,
# text_target=target, # type: ignore[arg-type]
# truncation=True,
# )
# else:
inputs = tokenizer(
target = choice["answer"] if is_enc_dec else None

# Record the EXACT question we fed to the model
variant_questions.append(text)
encoding = tokenizer(
text,
return_offsets_mapping=True,
add_special_tokens=False,
return_tensors="pt",
text_target=target, # type: ignore[arg-type]
truncation=True,
)
).to(device)
input_ids = assert_type(Tensor, encoding.input_ids)

if is_enc_dec:
answer = assert_type(Tensor, encoding.labels)
else:
encoding2 = tokenizer(
choice["answer"],
add_special_tokens=False,
return_tensors="pt",
).to(device)
answer = assert_type(Tensor, encoding2.input_ids)

# The offset_mapping is a sorted list of (start, end) tuples. We locate
# the start of the answer in the tokenized sequence with binary search.
offsets = inputs.pop("offset_mapping") if cfg.model.startswith("RWKV") else inputs.pop("offset_mapping").squeeze().tolist()
inputs = inputs if cfg.model.startswith("RWKV") else inputs.to(device)
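# Concatenate question and answer tokens, then truncate from the left so the
# answer (and as much preceding context as fits) is always kept.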
input_ids = torch.cat([input_ids, answer], dim=-1)
if max_len := tokenizer.model_max_length:
input_ids = input_ids[..., -max_len:]

# Run the forward pass
outputs = model(**inputs) if cfg.model.startswith("RWKV") else model(**inputs, output_hidden_states=True)
# Make sure we only pass the arguments that the model expects
inputs = dict(input_ids=input_ids)
if is_enc_dec:
inputs["labels"] = answer

with torch.autocast("cuda", enabled=torch.cuda.is_available()):
outputs = model(**inputs, output_hidden_states=True)

# Compute the log probability of the answer tokens if available
if has_lm_preds:
start, end = convert_span(
offsets, (answer_start, answer_start + len(choice["answer"]))
)
log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(
dim=-1
)
tokens = inputs.input_ids[..., start:end, None]
lm_preds[i, j] = log_p.gather(-1, tokens).sum()
answer_len = answer.shape[-1]

log_p = outputs.logits[..., -answer_len:, :].log_softmax(dim=-1)
tokens = answer[..., None]
lm_logits[i, j] = log_p.gather(-1, tokens).sum()

elif isinstance(outputs, Seq2SeqLMOutput):
# The cross entropy loss is averaged over tokens, so we need to
# multiply by the length to get the total log probability.
length = inputs.labels.shape[-1]
lm_preds[i, j] = -assert_type(Tensor, outputs.loss) * length
length = encoding.labels.shape[-1]
lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length

hiddens = outputs if cfg.model.startswith("RWKV") else (
hiddens = (
outputs.get("decoder_hidden_states") or outputs["hidden_states"]
)
# First element of list is the input embeddings
hiddens = hiddens if cfg.model.startswith("RWKV") else hiddens[1:]
hiddens = hiddens[1:]

# Throw out layers we don't care about
hiddens = [hiddens[i] for i in layer_indices]
@@ -245,17 +248,16 @@ def extract_hiddens(
for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)

text_inputs.append(variant_inputs)
text_questions.append(variant_questions)

out_record: dict[str, Any] = dict(
label=example["label"],
variant_ids=example["template_names"],
text_inputs=text_inputs,
text_questions=text_questions,
**hidden_dict,
)
if has_lm_preds:
# We only need the probability of the positive example since this is binary
out_record["model_preds"] = lm_preds.softmax(dim=-1)[..., 1]
out_record["model_logits"] = lm_logits

yield out_record
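
Because this commit stores raw `model_logits` rather than the softmaxed `model_preds` column it replaces, downstream code that still wants per-class probabilities has to normalize the logits itself. A minimal sketch of that read-back, assuming one stored row of shape (num_variants, num_classes) with illustrative values:

import torch

# Stand-in for one record's `model_logits`: summed answer log-likelihoods,
# one entry per (prompt variant, answer choice).
model_logits = torch.randn(4, 2)

# Per-class probabilities for each prompt variant.
probs = model_logits.softmax(dim=-1)

# Equivalent of the removed `model_preds` column: probability of class 1.
positive_prob = probs[..., 1]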

@@ -266,7 +268,11 @@ def _extraction_worker(**kwargs):


def extract(
cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None
cfg: "Extract",
*,
disable_cache: bool = False,
num_gpus: int = -1,
min_gpu_mem: int | None = None,
) -> DatasetDict:
"""Extract hidden states from a model and return a `DatasetDict` containing them."""

@@ -292,15 +298,18 @@ def get_splits() -> SplitDict:
dataset_name=available_splits.dataset_name,
)

model_cfg = None
if cfg.model.startswith("RWKV"):
model_cfg = RWKVConfig()
else:
model_cfg = AutoConfig.from_pretrained(cfg.model)
model_cfg = AutoConfig.from_pretrained(cfg.model)

ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)

ds_features = assert_type(Features, info.features)
label_col = (
cfg.prompts.label_columns[0]
if cfg.prompts.label_columns
else infer_label_column(ds_features)
)
num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col])
num_variants = cfg.prompts.num_variants
if num_variants < 0:
prompter = DatasetTemplates(ds_name, config_name)
@@ -309,7 +318,7 @@ def get_splits() -> SplitDict:
layer_cols = {
f"hidden_{layer}": Array3D(
dtype="int16",
shape=(num_variants, 2, model_cfg.hidden_size),
shape=(num_variants, num_classes, model_cfg.hidden_size),
)
for layer in cfg.layers or range(model_cfg.num_hidden_layers)
}
@@ -318,21 +327,20 @@
Value(dtype="string"),
length=num_variants,
),
"label": ClassLabel(names=["neg", "pos"]),
"text_inputs": Sequence(
"label": Value(dtype="int64"),
"text_questions": Sequence(
Sequence(
Value(dtype="string"),
length=2,
),
length=num_variants,
),
}

# Only add model_preds if the model is an autoregressive model
if is_autoregressive(model_cfg):
other_cols["model_preds"] = Sequence(
Value(dtype="float32"),
length=num_variants,
# Only add model_logits if the model is an autoregressive model
if is_autoregressive(model_cfg, not cfg.use_encoder_states):
other_cols["model_logits"] = Array2D(
shape=(num_variants, num_classes),
dtype="float32",
)

devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
@@ -361,7 +369,10 @@ def get_splits() -> SplitDict:

ds = dict()
for split, builder in builders.items():
builder.download_and_prepare(num_proc=len(devices))
builder.download_and_prepare(
download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None,
num_proc=len(devices),
)
ds[split] = builder.as_dataset(split=split)

return DatasetDict(ds)
return DatasetDict(ds)
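
For reference, a minimal way this extraction module might be driven end to end. The `Extract` and `PromptConfig` constructor arguments shown here are assumptions inferred from how `cfg` is used in this file, not something this commit defines:

from elk.extraction.extraction import Extract, extract
from elk.extraction.prompt_loading import PromptConfig

# Hypothetical configuration: a small autoregressive model and a single dataset.
cfg = Extract(
    model="EleutherAI/pythia-160m",
    prompts=PromptConfig(datasets=["imdb"]),
)

# Returns a DatasetDict with one entry per selected split, containing int16 hidden
# states (and model_logits for autoregressive models).
ds = extract(cfg, num_gpus=1)
print(ds)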
