fix evaluation during training for t5 #1551

Open · wants to merge 1 commit into base: master
24 changes: 11 additions & 13 deletions simpletransformers/t5/t5_model.py
@@ -12,24 +12,24 @@
 import numpy as np
 import pandas as pd
 import torch
-from torch.utils.tensorboard import SummaryWriter
 from torch.nn.utils.rnn import pad_sequence
+from torch.optim import AdamW
 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
 from tqdm.auto import tqdm, trange
+from transformers.models.byt5 import ByT5Tokenizer
+from transformers.models.mt5 import MT5Config, MT5ForConditionalGeneration
 from transformers.models.t5 import T5Config, T5ForConditionalGeneration, T5Tokenizer
 from transformers.optimization import (
+    Adafactor,
     get_constant_schedule,
     get_constant_schedule_with_warmup,
-    get_linear_schedule_with_warmup,
     get_cosine_schedule_with_warmup,
     get_cosine_with_hard_restarts_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
     get_polynomial_decay_schedule_with_warmup,
 )
-from torch.optim import AdamW
-from transformers.optimization import Adafactor
-from transformers.models.mt5 import MT5Config, MT5ForConditionalGeneration
-from transformers.models.byt5 import ByT5Tokenizer
 
 from simpletransformers.config.global_args import global_args
 from simpletransformers.config.model_args import T5Args
@@ -926,26 +926,24 @@ def eval_model(
             to_predict = [
                 prefix + ": " + input_text
                 for prefix, input_text in zip(
-                    eval_dataset["prefix"], eval_dataset["input_text"]
+                    eval_data["prefix"], eval_data["input_text"]
                 )
             ]
         else:
             to_predict = [
                 prefix + input_text
                 for prefix, input_text in zip(
-                    eval_dataset["prefix"], eval_dataset["input_text"]
+                    eval_data["prefix"], eval_data["input_text"]
                 )
             ]
         preds = self.predict(to_predict)
 
         if self.args.use_hf_datasets:
-            target_text = eval_dataset["target_text"]
+            target_text = eval_data["target_text"]
         else:
-            target_text = eval_dataset["target_text"].tolist()
+            target_text = eval_data["target_text"].tolist()
 
-        result = self.compute_metrics(
-            target_text, preds, **kwargs
-        )
+        result = self.compute_metrics(target_text, preds, **kwargs)
         self.results.update(result)
 
         if verbose:
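
The change points every read of the raw evaluation text at eval_data (the argument passed into eval_model) instead of eval_dataset, which further up in eval_model holds the processed dataset rather than the original text columns. Below is a minimal standalone sketch of the corrected text-building logic, not the library code itself; it assumes eval_data is a pandas DataFrame in the library's standard prefix / input_text / target_text format, with preprocess_inputs standing in for self.args.preprocess_inputs.

# Standalone sketch (assumptions noted above; build_eval_inputs is a
# hypothetical helper, not a simpletransformers function).
import pandas as pd

def build_eval_inputs(eval_data: pd.DataFrame, preprocess_inputs: bool = True):
    if preprocess_inputs:
        # With input preprocessing on, a ": " separator joins prefix and input,
        # matching the format the model saw during training.
        to_predict = [
            prefix + ": " + input_text
            for prefix, input_text in zip(eval_data["prefix"], eval_data["input_text"])
        ]
    else:
        to_predict = [
            prefix + input_text
            for prefix, input_text in zip(eval_data["prefix"], eval_data["input_text"])
        ]
    # Plain pandas path; an HF dataset column is already list-like, so the
    # library skips .tolist() in that branch.
    target_text = eval_data["target_text"].tolist()
    return to_predict, target_text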
4 changes: 3 additions & 1 deletion tests/test_t5.py
@@ -30,13 +30,15 @@ def test_t5():
"save_model_every_epoch": False,
"max_length": 20,
"num_beams": 1,
"evaluate_generated_text": True,
"evaluate_during_training": True,
}

# Create T5 Model
model = T5Model("t5", "t5-base", args=model_args, use_cuda=False)

# Train T5 Model on new task
model.train_model(train_df)
model.train_model(train_df, eval_data=eval_data)

# Evaluate T5 Model on new task
model.eval_model(eval_df)
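
The updated test exercises the fixed path end to end: both flags are set and an eval set is passed to train_model. A usage sketch follows; the DataFrame contents and any args outside the visible hunk are illustrative assumptions, not copied from the test file.

# Illustrative usage (data and extra args are assumptions, not from the test).
import pandas as pd
from simpletransformers.t5 import T5Model

train_df = pd.DataFrame(
    [
        ["binary classification", "Anakin was Luke's father", "1"],
        ["binary classification", "Luke was a Sith Lord", "0"],
    ],
    columns=["prefix", "input_text", "target_text"],
)
eval_df = train_df.copy()

model_args = {
    "num_train_epochs": 1,
    "save_model_every_epoch": False,
    "max_length": 20,
    "num_beams": 1,
    # Both flags are needed for generated-text evaluation while training.
    "evaluate_generated_text": True,
    "evaluate_during_training": True,
}

model = T5Model("t5", "t5-base", args=model_args, use_cuda=False)

# Passing an eval set to train_model triggers the during-training
# evaluation path that this PR repairs.
model.train_model(train_df, eval_data=eval_df)
model.eval_model(eval_df)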