diff --git a/experiments/util/models.py b/experiments/util/models.py index 223d2c7..7adc69a 100644 --- a/experiments/util/models.py +++ b/experiments/util/models.py @@ -26,9 +26,8 @@ AutoModelForSeq2SeqLM, BartForConditionalGeneration, ) -from transformers import ( - logging as tf_logging, -) +from transformers import logging as tf_logging +from transformers import utils as tf_utils from memsave_torch.nn import ( MemSaveConv2d, @@ -349,6 +348,8 @@ def __init__( tf_logging.disable_progress_bar() tf_logging.set_verbosity_error() +tf_utils.logging.captureWarnings(True) + hf_transformers_models_map = { "gpt2": _HF_model("gpt2", {}, lm_head_name="lm_head"),