-
Notifications
You must be signed in to change notification settings - Fork 0
/
st6.LLM_eval.py
72 lines (60 loc) · 2.89 KB
/
st6.LLM_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import csv
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import argparse
def escape_prompt(prompt):
    """Return *prompt* with newlines removed and double quotes escaped.

    Three literal substitutions are applied, in this order:

    1. every newline character is deleted (nothing is inserted in its place);
    2. every double quote is prefixed with a backslash;
    3. every " . " (space, period, space) becomes ". ".

    The order matters: deleting a newline can create a " . " sequence that
    the last step then collapses.
    """
    substitutions = (
        ('\n', ''),
        ('"', '\\"'),
        (' . ', '. '),
    )
    cleaned = prompt
    for needle, replacement in substitutions:
        cleaned = cleaned.replace(needle, replacement)
    return cleaned
def main(input_tsv_file, output_tsv_file, model_name):
    """Generate a model completion for every prompt row of a TSV file.

    Rows are read from ``input_tsv_file`` as tab-separated values. For each
    row with at least four columns, the fourth column is cleaned with
    ``escape_prompt``, prefixed with ``<s>`` and sent to a text-generation
    pipeline (at most 256 new tokens). The generated continuation is printed,
    appended to the row as a new last column, and the row is written to
    ``output_tsv_file``. Rows with fewer than four columns are reported on
    stdout and are NOT written to the output file.

    Args:
        input_tsv_file: Path of the TSV file holding the prompts.
        output_tsv_file: Path of the TSV file to create (overwritten).
        model_name: Name or path passed to ``from_pretrained`` for both the
            tokenizer and the causal language model.
    """
    # Device selection: CUDA > MPS > CPU.
    # To target a specific GPU on a multi-GPU machine, use e.g. "cuda:1".
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using device: CUDA")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using device: MPS (Metal Performance Shaders) on Apple Silicon")
    else:
        device = torch.device("cpu")
        print("Using device: CPU")

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Give the pipeline the torch.device itself rather than an integer index:
    # an int is interpreted as an accelerator ordinal (historically always
    # "cuda:N"), which does not describe the MPS case and could disagree with
    # where the model was just placed.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    # Explicit UTF-8 keeps non-ASCII prompt/model text intact regardless of
    # the platform's default encoding; newline='' is what the csv module
    # expects so it can handle line endings itself.
    with open(input_tsv_file, 'r', newline='', encoding='utf-8') as infile, \
            open(output_tsv_file, 'w', newline='', encoding='utf-8') as outfile:
        reader = csv.reader(infile, delimiter='\t')
        writer = csv.writer(outfile, delimiter='\t')

        for row in reader:
            if len(row) < 4:
                print("Row has fewer than 4 columns, skipping:", row)
                continue

            # NOTE(review): the tokenizer may already add a BOS token on its
            # own; confirm the explicit '<s>' prefix is wanted for this model.
            prompt = '<s>' + escape_prompt(row[3])

            # Generate text
            response = generator(prompt, max_new_tokens=256)[0]['generated_text']

            # The pipeline returns prompt + continuation; keep only the
            # newly generated part.
            new_text = response[len(prompt):].strip()
            print(new_text)

            row.append(new_text)
            writer.writerow(row)
if __name__ == "__main__":
    # Command-line entry point: parse the three paths/names and run main().
    #
    # Each option declares a default, so none of them is marked required;
    # combining required=True with default= would make the default
    # unreachable. Invocations that pass all three flags behave as before.
    parser = argparse.ArgumentParser(
        description='Process a TSV file using a HuggingFace Transformer model.',
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--input_tsv_file', type=str,
        default='./validation/context_based/16169070_PPI.LLM_input_and_output.tsv',
        help='Path to the input TSV file with LLM prompts.',
    )
    parser.add_argument(
        '--output_tsv_file', type=str,
        default='./validation/context_based/16169070_PPI.LLM_output.tsv',
        help='Path to the output TSV file.',
    )
    parser.add_argument(
        '--model_name', type=str,
        default='Timofey/Gemma-2-9b-it-Fused_PPI',
        help='Name or path of the HuggingFace Transformer model.',
    )

    args = parser.parse_args()
    main(args.input_tsv_file, args.output_tsv_file, args.model_name)