Update:
Unload the BERT model.
Ignore loading warning.
Artrajz committed Nov 4, 2023
1 parent 70d4b45 commit 1799182
Showing 3 changed files with 19 additions and 5 deletions.
5 changes: 5 additions & 0 deletions ModelManager.py
@@ -253,6 +253,10 @@ def unload_model(self, model_type_value: str, model_id: str):
if key == model_id:
break
start += ns

if model_type == ModelType.BERT_VITS2:
for bert_model_name in self.models[model_type][model_id][1].bert_model_names.values():
self.bert_handler.release_bert(bert_model_name)

del self.sid2model[model_type][start:start + n_speakers]
del self.voice_speakers[model_type.value][start:start + n_speakers]
@@ -266,6 +270,7 @@ def unload_model(self, model_type_value: str, model_id: str):

state = True
self.notify("model_unloaded", model_manager=self)
self.logger.info(f"Unloading success.")
except Exception as e:
self.logger.info(f"Unloading failed. {e}")
state = False
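The hunk above hooks BERT release into model unloading: when a Bert-VITS2 model is unloaded, every BERT model it referenced is released through the shared handler. A minimal, self-contained sketch of that flow (stand-in classes with hypothetical names, not the project's actual ModelManager or BertHandler):

class DummyBertHandler:
    """Stand-in for the shared BertHandler; only tracks reference counts."""

    def __init__(self):
        self.ref_counts = {"CHINESE_ROBERTA_WWM_EXT_LARGE": 2}

    def release_bert(self, bert_model_name):
        self.ref_counts[bert_model_name] -= 1
        print(f"{bert_model_name} references left: {self.ref_counts[bert_model_name]}")


class DummyBertVits2Model:
    """Stand-in for a loaded Bert-VITS2 model exposing its BERT dependencies."""
    bert_model_names = {"ZH": "CHINESE_ROBERTA_WWM_EXT_LARGE"}


def unload_bert_vits2(bert_handler, model):
    # Mirrors the added loop: release every BERT model this model referenced.
    for bert_model_name in model.bert_model_names.values():
        bert_handler.release_bert(bert_model_name)


unload_bert_vits2(DummyBertHandler(), DummyBertVits2Model())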
13 changes: 8 additions & 5 deletions bert_vits2/text/bert_handler.py
@@ -1,10 +1,11 @@
import gc
import logging
import os

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

from utils.config_manager import global_config as config
from logger import logger
from utils.download import download_file
from .chinese_bert import get_bert_feature as zh_bert
from .english_bert_mock import get_bert_feature as en_bert
@@ -53,7 +54,7 @@ def _download_model(self, bert_model_name, target_path=None):
"https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/spm.model",
],
}

SHA256 = {
"CHINESE_ROBERTA_WWM_EXT_LARGE": "4ac62d49144d770c5ca9a5d1d3039c4995665a080febe63198189857c6bd11cd",
"BERT_BASE_JAPANESE_V3": "e172862e0674054d65e0ba40d67df2a4687982f589db44aa27091c386e5450a4",
@@ -70,9 +71,9 @@ def _download_model(self, bert_model_name, target_path=None):
expected_sha256 = SHA256[bert_model_name]
success, message = download_file(urls, target_path, expected_sha256=expected_sha256)
if not success:
logger.error(f"Failed to download {bert_model_name}: {message}")
logging.error(f"Failed to download {bert_model_name}: {message}")
else:
logger.info(f"{message}")
logging.info(f"{message}")

def load_bert(self, bert_model_name, max_retries=3):
if bert_model_name not in self.bert_models:
@@ -119,7 +120,9 @@ def release_bert(self, bert_model_name):
if count == 0:
# When the reference count drops to 0, delete the model and release its resources
del self.bert_models[bert_model_name]
logger(f"Model {bert_model_name} has been released.")
gc.collect()
torch.cuda.empty_cache()
logging.info(f"BERT model {bert_model_name} has been released.")
else:
tokenizer, model = self.bert_models[bert_model_name][:2]
self.bert_models[bert_model_name] = (tokenizer, model, count)
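The release path above only frees a BERT model once its last user lets go, then forces garbage collection and empties the CUDA cache. A minimal sketch of that reference-counting pattern (simplified, assumed names, not the project's actual BertHandler; tokenizer handling omitted):

import gc
import logging

import torch


class RefCountedBertStore:
    def __init__(self):
        # name -> (model, reference_count)
        self.models = {}

    def acquire(self, name, loader):
        if name in self.models:
            model, count = self.models[name]
            self.models[name] = (model, count + 1)
        else:
            self.models[name] = (loader(), 1)
        return self.models[name][0]

    def release(self, name):
        model, count = self.models[name]
        count -= 1
        if count == 0:
            # Last reference gone: drop the model and reclaim memory.
            del self.models[name]
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logging.info(f"BERT model {name} has been released.")
        else:
            self.models[name] = (model, count)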
6 changes: 6 additions & 0 deletions bert_vits2/utils.py
@@ -47,6 +47,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
logger.error(f"{k} is not in the checkpoint")
logger.warning(
f"If you're using a transitional version, please add the \"version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.")
elif "en_bert_proj" in k:
v = torch.zeros_like(v)
if version is None:
logger.error(f"{k} is not in the checkpoint")
logger.warning(
f"If you're using an older version of the model, consider adding the \"version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.1.1\"")
else:
logger.error(f"{k} is not in the checkpoint")

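The utils.py change above is what "Ignore loading warning" refers to: missing "en_bert_proj" weights are zero-initialized instead of being reported as a hard error, so checkpoints from older model versions still load. A rough sketch of the idea (a hypothetical helper, not the actual load_checkpoint):

import torch


def fill_missing_keys(model_state, checkpoint_state):
    """Build a state dict covering model_state, tolerating missing en_bert_proj keys."""
    merged = {}
    for k, v in model_state.items():
        if k in checkpoint_state:
            merged[k] = checkpoint_state[k]
        elif "en_bert_proj" in k:
            # Older checkpoints lack the English BERT projection; zero-init instead.
            merged[k] = torch.zeros_like(v)
        else:
            raise KeyError(f"{k} is not in the checkpoint")
    return merged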
