ppl_data.py
'''
Some of the code refers to
https://github.com/IST-DASLab/gptq/blob/main/datautils.py
'''
import random
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data.dataset import Dataset

def get_wikitext2(seq_len, tokenizer):
    # Raw WikiText-2 train/test splits; tokenization happens later in process_data.
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    return traindata, testdata

def get_ptb(seq_len, tokenizer):
    # Raw Penn Treebank train/validation splits (the validation split is used as the test set here).
    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
    valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
    return traindata, valdata

class IndexDataset(Dataset):
    '''Thin Dataset wrapper around a pre-built tensor of token-id chunks.'''
    def __init__(self, tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

def process_data(samples, tokenizer, seq_len, field_name):
    # Tokenize the whole corpus as one long string, then split it into
    # non-overlapping chunks of seq_len tokens (the trailing remainder is dropped).
    test_ids = tokenizer("\n\n".join(samples[field_name]), return_tensors='pt').input_ids[0]
    test_ids_batch = []
    nsamples = test_ids.numel() // seq_len

    for i in range(nsamples):
        batch = test_ids[(i * seq_len):((i + 1) * seq_len)]
        test_ids_batch.append(batch)
    test_ids_batch = torch.stack(test_ids_batch)
    return IndexDataset(tensors=test_ids_batch)

def get_loaders(name, tokenizer, seq_len=2048, batch_size=8):
    # Return the raw train split and a DataLoader over fixed-length test chunks
    # for perplexity evaluation.
    if 'wikitext2' in name:
        train_data, test_data = get_wikitext2(seq_len, tokenizer)
        test_dataset = process_data(test_data, tokenizer, seq_len, 'text')
    elif 'ptb' in name:
        train_data, test_data = get_ptb(seq_len, tokenizer)
        test_dataset = process_data(test_data, tokenizer, seq_len, 'sentence')
    else:
        raise ValueError(f'Unknown dataset name: {name}')

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_data, test_loader