WIP: update bart #19

Open
wants to merge 8 commits into base: master
9 changes: 9 additions & 0 deletions .gitignore
@@ -0,0 +1,9 @@
# Ignore files
*.pyc
extract_features_timbackup.py

# Ignore directories
features/
dataset/videos/
__pycache__

101 changes: 87 additions & 14 deletions bart/bart.py
@@ -2,37 +2,110 @@
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import T5ForConditionalGeneration, BartForConditionalGeneration
from transformers.modeling_outputs import Seq2SeqLMOutput #TJH added

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
""" TJH: from modelling_bart.py NOT currently used
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id

assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

return shifted_input_ids



class MyBart(BartForConditionalGeneration):
""" TJH: adding , past_key_values=None to forward(..) takes us to next keyword error 'head_mask'

Original forward below replaced with new forward (and new outputs = self.model(...) call below)
def forward(self, input_ids, attention_mask=None, encoder_outputs=None,
decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None,
use_cache=False, is_training=False):

New version assumes that for training, decoder inputs are in labels
and for generation, decoder inputs are in decoder_input_ids
"""
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None, #TJH In 4.4.2 labels contains what in unifiedqa is called decoder_input_ids
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
#TJH: Added for compatibility with 4.4.2
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if is_training:
decoder_start_token_id = self.config.decoder_start_token_id
_decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape)
_decoder_input_ids[..., 1:] = decoder_input_ids[..., :-1].clone()
_decoder_input_ids[..., 0] = decoder_start_token_id
else:
_decoder_input_ids = decoder_input_ids.clone()

if labels is not None: #TJH added for compatibility with other 4.4.2 seq2seq models
if decoder_input_ids is None:
#TJH: how it is done in modeling_bart.py. Using the unifiedQA method instead
# decoder_input_ids = shift_tokens_right(
# labels, self.config.pad_token_id, self.config.decoder_start_token_id
# )
decoder_start_token_id = self.config.decoder_start_token_id
decoder_input_ids = labels.new_zeros(labels.shape)
decoder_input_ids[..., 1:] = labels[..., :-1].clone()
decoder_input_ids[..., 0] = decoder_start_token_id

# TJH: below from modeling_bart.py
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, #TJH: no underscore
encoder_outputs=encoder_outputs,
decoder_input_ids=_decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
if is_training:

loss = None
if labels is not None: #TJH labels is not None instead of is_training
loss_fct = nn.CrossEntropyLoss(reduce=False)
losses = loss_fct(lm_logits.view(-1, self.config.vocab_size),
decoder_input_ids.view(-1))
labels.view(-1))
loss = torch.sum(losses * decoder_attention_mask.float().view(-1))
return loss
return (lm_logits, ) + outputs[1:]

if not return_dict: #TJH: from modeling_bart.py
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output

return Seq2SeqLMOutput( #TJH: from modeling_bart.py.
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)


def generate_from_string(self, _input, tokenizer=None, **generator_args):
assert tokenizer is not None
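Note (not part of the diff): a minimal sketch of the right-shift the new forward applies when only labels are given, assuming a BART-style config where decoder_start_token_id=2 and pad_token_id=1; the tensors are illustrative.

import torch

labels = torch.tensor([[31, 42, -100, -100]])   # -100 marks ignored/padding positions
decoder_start_token_id, pad_token_id = 2, 1     # values assumed from facebook/bart-large config

decoder_input_ids = labels.new_zeros(labels.shape)
decoder_input_ids[..., 1:] = labels[..., :-1].clone()   # shift every label one step right
decoder_input_ids[..., 0] = decoder_start_token_id      # prepend the decoder start token
decoder_input_ids.masked_fill_(decoder_input_ids == -100, pad_token_id)  # as in shift_tokens_right
# decoder_input_ids is now tensor([[ 2, 31, 42,  1]]):
# each decoder position only sees labels from earlier positions.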
6 changes: 4 additions & 2 deletions bart/data.py
@@ -110,10 +110,12 @@ def load_dataset(self, tokenizer, do_return=False):
questions = ["<s> "+question for question in questions]
answers = ["<s> " +answer for answer in answers]
question_input = tokenizer.batch_encode_plus(questions,
pad_to_max_length=True,
truncation=True, #TJH added
padding='max_length', #TJH was pad_to_max_length=True,
max_length=self.args.max_input_length)
answer_input = tokenizer.batch_encode_plus(answers,
pad_to_max_length=True,
truncation=True, #TJH added
padding='max_length', #TJH was pad_to_max_length=True,
max_length=self.args.max_output_length)
input_ids, attention_mask = question_input["input_ids"], question_input["attention_mask"]
decoder_input_ids, decoder_attention_mask = answer_input["input_ids"], answer_input["attention_mask"]
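Note (not part of the diff): pad_to_max_length=True is deprecated in recent transformers releases, which is what motivates the padding/truncation change above. A rough stand-alone equivalent of the new call, with an illustrative max_length:

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
questions = ["<s> which is best conductor? \\n (a) iron (b) feather"]

# old, deprecated form:
#   tokenizer.batch_encode_plus(questions, pad_to_max_length=True, max_length=512)
# new form: truncate long inputs and pad every example to the same fixed length
question_input = tokenizer.batch_encode_plus(questions,
                                             truncation=True,
                                             padding='max_length',
                                             max_length=512)
print(len(question_input["input_ids"][0]))  # 512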
33 changes: 23 additions & 10 deletions bart/run.py
@@ -11,7 +11,7 @@
from bart import MyBart

def run(args, logger):
tokenizer = BartTokenizer.from_pretrained("bart-large")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large") #TJH: bart-large

if args.is_unifiedqa:
dev_data = UnifiedQAData(logger, args, args.predict_file, False)
@@ -31,10 +31,11 @@ def run(args, logger):
train_data.load_dataloader()

if args.checkpoint is not None:
model = MyBart.from_pretrained("bart-large",
state_dict=torch.load(args.checkpoint))
model = MyBart.from_pretrained("facebook/bart-large",
state_dict=torch.load(args.checkpoint)) #TJH: bart-large
logger.info("Loading checkpoint from {}".format(args.checkpoint)) #TJH Added
else:
model = MyBart.from_pretrained("bart-large")
model = MyBart.from_pretrained("facebook/bart-large") #TJH: bart-large
if args.n_gpu>1:
model = torch.nn.DataParallel(model)
if args.n_gpu>0:
@@ -53,8 +54,8 @@

if args.do_predict:
checkpoint = os.path.join(args.output_dir, 'best-model.pt') if args.checkpoint is None else args.checkpoint
model = MyBart.from_pretrained("bart-large",
state_dict=torch.load(checkpoint))
model = MyBart.from_pretrained("facebook/bart-large",
state_dict=torch.load(checkpoint)) #TJH: bart-large
logger.info("Loading checkpoint from {}".format(checkpoint))
if args.n_gpu>0:
model.to(torch.device("cuda"))
@@ -83,12 +84,20 @@ def _convert(key):

logger.info("Starting training!")
for epoch in range(int(args.num_train_epochs)):
if args.verbose:
logger.info("Starting Epoch %d" % (epoch)) #TJH added
for batch in train_data.dataloader:
if args.verbose and global_step % 100 == 0:
logger.info("Epoch %d Global Step %d" % (epoch, global_step)) #TJH Added
global_step += 1
batch = [b.to(torch.device("cuda")) for b in batch]
loss = model(input_ids=batch[0], attention_mask=batch[1],
decoder_input_ids=batch[2], decoder_attention_mask=batch[3],
is_training=True)
# TJH: this was the original unifiedqa:
# loss = model(input_ids=batch[0], attention_mask=batch[1],
# decoder_input_ids=batch[2], decoder_attention_mask=batch[3],
# is_training=True)
outputs = model(input_ids=batch[0], attention_mask=batch[1],
labels=batch[2], decoder_attention_mask=batch[3])
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] #TJH added
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if torch.isnan(loss).data:
@@ -117,6 +126,8 @@ def _convert(key):
torch.save(model_state_dict, os.path.join(args.output_dir,
"best-model-{}.pt".format(str(global_step).zfill(6))))
else:
if args.verbose:
logger.info("Step %d Starting inference.." % (global_step)) #TJH Added
model.eval()
curr_em = inference(model if args.n_gpu==1 else model.module, dev_data)
logger.info("Step %d Train loss %.2f %s %.2f%% on epoch=%d" % (
@@ -138,8 +149,10 @@ def _convert(key):
stop_training = False
else:
wait_step += 1
logger.info("No improvement. Number of wait steps: %d of max wait steps: %d" % (wait_step, args.wait_step))
if wait_step >= args.wait_step:
stop_training = True
logger.info("Early Stopping due to no improvement after %d wait steps!" % (wait_step)) #TJH Added
break
model.train()
if stop_training:
@@ -155,7 +168,7 @@ def inference(model, dev_data, save_predictions=False):
outputs = model.generate(input_ids=batch[0],
attention_mask=batch[1],
num_beams=dev_data.args.num_beams,
min_lnegth=1,
min_length=1, #TJH: was min_lnegth
max_length=dev_data.args.max_output_length,
early_stopping=True,)
for input_, output in zip(batch[0], outputs):
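Note (not part of the diff): a minimal sketch of the new training call above, assuming the updated MyBart.forward and transformers 4.4.2; the strings, max_length values, and checkpoint-free model load are illustrative.

from transformers import BartTokenizer
from bart import MyBart

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = MyBart.from_pretrained("facebook/bart-large")

enc = tokenizer(["<s> which is best conductor? \\n (a) iron (b) feather"],
                truncation=True, padding='max_length', max_length=32, return_tensors="pt")
dec = tokenizer(["<s> iron"],
                truncation=True, padding='max_length', max_length=8, return_tensors="pt")

# labels now carries what the old unifiedqa forward received as decoder_input_ids
outputs = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
                labels=dec["input_ids"], decoder_attention_mask=dec["attention_mask"])
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
loss.backward()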
14 changes: 14 additions & 0 deletions bart/rundropft.sh
@@ -0,0 +1,14 @@
# training command
# other options: --do_predict --skip_inference --debug --checkpoint ${unifiedqa_checkpoint}
# --prefix dev_ --prefix test_ --checkpoint_step

python cli.py --do_train --output_dir /data/thar011/out/unifiedqa_dropft \
--checkpoint /data/thar011/ckpts/unifiedqa-bart-large-allenai/unifiedQA-uncased/best-model.pt \
--is_unifiedqa \
--train_file /data/thar011/data/unifiedqa/train.tsv \
--predict_file /data/thar011/data/unifiedqa/dev.tsv \
--train_batch_size 64 \
--predict_batch_size 64 \
--append_another_bos --do_lowercase \
--eval_period 10000 --verbose

18 changes: 18 additions & 0 deletions bart/runpredict.sh
@@ -0,0 +1,18 @@
#run predictions - picks up best_model from output_dir otherwise can specify --checkpoint

python cli.py --do_predict --output_dir /data/thar011/out/unifiedqa_2gputest_from_uqackpt \
--predict_file /data/thar011/data/unifiedqa/drop/dev.tsv \
--predict_batch_size 64 \
--append_another_bos --do_lowercase \
--verbose \
--prefix dev_drop_


python cli.py --do_predict --output_dir /data/thar011/out/unifiedqa_2gputest_from_uqackpt \
--predict_file /data/thar011/data/unifiedqa/ropes/dev.tsv \
--predict_batch_size 64 \
--append_another_bos --do_lowercase \
--verbose \
--prefix dev_ropes_


17 changes: 17 additions & 0 deletions bart/runtrain.sh
@@ -0,0 +1,17 @@
# training command
# other options: --do_predict --skip_inference --debug --checkpoint ${unifiedqa_checkpoint}
# --prefix dev_ --prefix test_
# --checkpoint /data/thar011/ckpts/unifiedqa-bart-large-allenai/unifiedQA-uncased/best-model.pt \

python cli.py --do_train --output_dir /data/thar011/out/unifiedqa_2gputest_from_bart \
--is_unifiedqa \
--train_file /data/thar011/data/unifiedqa/train.tsv \
--predict_file /data/thar011/data/unifiedqa/dev.tsv \
--train_batch_size 32 \
--predict_batch_size 32 \
--append_another_bos --do_lowercase \
--eval_period 10000 --verbose \
--num_train_epochs 10000 \
--gradient_accumulation_steps 2 \
--wait_step 10

14 changes: 9 additions & 5 deletions bart/unified_data.py
@@ -23,8 +23,8 @@ def __init__(self, logger, args, data_path, is_training):
"boolq",
"race_string",
"openbookqa"]
self.data_path = data_path
self.data_type = data_path.split("/")[-1][:-4]
self.data_path = data_path #TJH this would be ../unifiedqa/train.tsv
self.data_type = data_path.split("/")[-1][:-4] #TJH strip .tsv from filename appearing after final "/"
assert self.data_type in ["train", "dev", "test"]

if args.debug:
@@ -96,10 +96,14 @@ def load_dataset(self, tokenizer):
questions = ["<s> "+question for question in questions]
answers = ["<s> " +answer for answer in answers]
question_input = self.tokenizer.batch_encode_plus(questions,
pad_to_max_length=True,
max_length=self.args.max_input_length)
truncation=True, #TJH added
padding='max_length', #TJH was pad_to_max_length=True,
max_length=self.args.max_input_length)
answer_input = self.tokenizer.batch_encode_plus(answers,
pad_to_max_length=True)
truncation=True, #TJH added
padding='max_length', #TJH was pad_to_max_length=True,
max_length=self.args.max_input_length)

input_ids, attention_mask = question_input["input_ids"], question_input["attention_mask"]
decoder_input_ids, decoder_attention_mask = answer_input["input_ids"], answer_input["attention_mask"]
print ("Finish tokenizering...")
78 changes: 78 additions & 0 deletions tjh/basic_tests.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 31 13:05:08 2021

@author: thar011

UnifiedQA initial tests:
T5 checkpoints tests
followed by BART tests

"""

# UnifiedQA T5 checkpoints tests:

from transformers import AutoTokenizer, T5ForConditionalGeneration

# inference on <= 3B models work
#model_name = "allenai/unifiedqa-t5-small" # you can specify the model size here
#model_name = "allenai/unifiedqa-t5-base" # you can specify the model size here
model_name = "allenai/unifiedqa-t5-large" # you can specify the model size here
#model_name = "allenai/unifiedqa-t5-3b" # you can specify the model size here
model_name = "allenai/unifiedqa-t5-11b" # you can specify the model size here
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

def run_model(input_string, **generator_args):
input_ids = tokenizer.encode(input_string, return_tensors="pt")
res = model.generate(input_ids, **generator_args)
return tokenizer.batch_decode(res, skip_special_tokens=True)


run_model("which is best conductor? \\n (a) iron (b) feather") #['iron']
run_model("which is best conductor? \\n ") # ['no answer>']
run_model("which is best conductor - iron or feather? \\n ") # ['iron']
run_model("Name a conductor of electricty? \\n Name any conductor") # ['any conductor']
run_model("Name a conductor of electricty? \\n ") # ['yes']
run_model("Name a conductor of electricity: \\n ") # ['yes']
run_model("What is 53 + 9521? \\n ") # ['no answer>']



run_model("scott filled a tray with juice and put it in a freezer. the next day, scott opened the freezer. how did the juice most likely change? \\n (a) it condensed. (b) it evaporated. (c) it became a gas. (d) it became a solid.")

run_model("which is best conductor? \\n (a) iron (b) feather (c) wood (d) plastic",
temperature=0.9, num_return_sequences=4, num_beams=20)


# BART tests (run from unifiedqa-tjh/bart directory):

import torch
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
from bart import MyBart

base_model = "facebook/bart-large"
#unifiedqa_path = "unifiedQA-uncased/best-model.pt" # path to the downloaded checkpoint
unifiedqa_path = "/data/thar011/ckpts/unifiedqa-bart-large-allenai/unifiedQA-uncased/best-model.pt" # path to the downloaded checkpoint

tokenizer = BartTokenizer.from_pretrained(base_model)
model = MyBart.from_pretrained(base_model, state_dict=torch.load(unifiedqa_path))
model.eval()

# ERROR: TypeError: forward() got an unexpected keyword argument 'past_key_values'
x = model.generate_from_string("Which is best conductor? \\n (A) iron (B) feather", tokenizer=tokenizer)
print(x)

x = model.generate_from_string("What is the sum of 3 and 5? \\n (A) 8 (B) 3 (C) 5 (D) 10", tokenizer=tokenizer)
print(x)


#try basic bart model (no error):
model = BartForConditionalGeneration.from_pretrained(base_model)
model.eval()
run_model("which is best conductor? \\n (a) iron (b) feather") #['whichwhich is best conductor?']



