From 17991824af0b397a1c1f9e44e3ee3a2ce67835a7 Mon Sep 17 00:00:00 2001 From: Artrajz <969242373@qq.com> Date: Sat, 4 Nov 2023 17:50:20 +0800 Subject: [PATCH] Update: Unload the BERT model. Ignore loading warning. --- ModelManager.py | 5 +++++ bert_vits2/text/bert_handler.py | 13 ++++++++----- bert_vits2/utils.py | 6 ++++++ 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/ModelManager.py b/ModelManager.py index 3659f96..873b380 100644 --- a/ModelManager.py +++ b/ModelManager.py @@ -253,6 +253,10 @@ def unload_model(self, model_type_value: str, model_id: str): if key == model_id: break start += ns + + if model_type == ModelType.BERT_VITS2: + for bert_model_name in self.models[model_type][model_id][1].bert_model_names.values(): + self.bert_handler.release_bert(bert_model_name) del self.sid2model[model_type][start:start + n_speakers] del self.voice_speakers[model_type.value][start:start + n_speakers] @@ -266,6 +270,7 @@ def unload_model(self, model_type_value: str, model_id: str): state = True self.notify("model_unloaded", model_manager=self) + self.logger.info(f"Unloading success.") except Exception as e: self.logger.info(f"Unloading failed. {e}") state = False diff --git a/bert_vits2/text/bert_handler.py b/bert_vits2/text/bert_handler.py index 5261e3a..aa705e6 100644 --- a/bert_vits2/text/bert_handler.py +++ b/bert_vits2/text/bert_handler.py @@ -1,10 +1,11 @@ +import gc import logging import os +import torch from transformers import AutoTokenizer, AutoModelForMaskedLM from utils.config_manager import global_config as config -from logger import logger from utils.download import download_file from .chinese_bert import get_bert_feature as zh_bert from .english_bert_mock import get_bert_feature as en_bert @@ -53,7 +54,7 @@ def _download_model(self, bert_model_name, target_path=None): "https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/spm.model", ], } - + SHA256 = { "CHINESE_ROBERTA_WWM_EXT_LARGE": "4ac62d49144d770c5ca9a5d1d3039c4995665a080febe63198189857c6bd11cd", "BERT_BASE_JAPANESE_V3": "e172862e0674054d65e0ba40d67df2a4687982f589db44aa27091c386e5450a4", @@ -70,9 +71,9 @@ def _download_model(self, bert_model_name, target_path=None): expected_sha256 = SHA256[bert_model_name] success, message = download_file(urls, target_path, expected_sha256=expected_sha256) if not success: - logger.error(f"Failed to download {bert_model_name}: {message}") + logging.error(f"Failed to download {bert_model_name}: {message}") else: - logger.info(f"{message}") + logging.info(f"{message}") def load_bert(self, bert_model_name, max_retries=3): if bert_model_name not in self.bert_models: @@ -119,7 +120,9 @@ def release_bert(self, bert_model_name): if count == 0: # 当引用计数为0时,删除模型并释放其资源 del self.bert_models[bert_model_name] - logger(f"Model {bert_model_name} has been released.") + gc.collect() + torch.cuda.empty_cache() + logging.info(f"BERT model {bert_model_name} has been released.") else: tokenizer, model = self.bert_models[bert_model_name][:2] self.bert_models[bert_model_name] = (tokenizer, model, count) diff --git a/bert_vits2/utils.py b/bert_vits2/utils.py index 5ac6252..1f850d1 100644 --- a/bert_vits2/utils.py +++ b/bert_vits2/utils.py @@ -47,6 +47,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False logger.error(f"{k} is not in the checkpoint") logger.warning( f"If you're using a transitional version, please add the \"version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.") + elif "en_bert_proj" in k: + v = torch.zeros_like(v) + if version is None: + logger.error(f"{k} is not in the checkpoint") + logger.warning( + f"If you're using an older version of the model, consider adding the \"version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.1.1\"") else: logger.error(f"{k} is not in the checkpoint")