diff --git a/vits/vits.py b/vits/vits.py
index f0fd777..45e6cec 100644
--- a/vits/vits.py
+++ b/vits/vits.py
@@ -77,7 +77,7 @@ def infer(self, params):
         x_tst_lengths = LongTensor([params.get("stn_tst").size(0)]).to(self.device)
         x_tst_prosody = torch.FloatTensor(params.get("char_embeds")).unsqueeze(0).to(
             self.device) if self.bert_embedding else None
-        sid = params.get("sid").to(self.device) if not self.bert_embedding else None
+        sid = params.get("sid").to(self.device)
         emotion = params.get("emotion").to(self.device) if self.emotion_embedding else None
         audio = self.net_g_ms.infer(x=x_tst,
@@ -100,10 +100,9 @@ def get_infer_param(self, length_scale, noise_scale, noise_scale_w, text=None, s
         if self.model_type != "hubert":
             if self.bert_embedding:
                 stn_tst, char_embeds = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)
-                sid = None
             else:
                 stn_tst = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)
-                sid = LongTensor([speaker_id])
+            sid = LongTensor([speaker_id])
 
         if self.model_type == "w2v2":
             # if emotion_reference.endswith('.npy'):
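
Net effect of the change: the speaker id is no longer suppressed when BERT prosody embeddings are enabled. Previously `get_infer_param` set `sid = None` on the `bert_embedding` path and `infer` dropped it again with `if not self.bert_embedding else None`; now every non-hubert model builds `sid = LongTensor([speaker_id])` and `infer` passes it through unconditionally. A minimal standalone sketch of the resolved behaviour follows (not the repository's code; `model_type`, `bert_embedding`, and `speaker_id` stand in for the corresponding attributes in `vits.py`, and returning `None` on the hubert path is an assumption, since the diff does not show that branch):

```python
from torch import LongTensor


def resolve_sid_before(model_type: str, bert_embedding: bool, speaker_id: int):
    # Old behaviour: BERT-embedding models never received a speaker id.
    if model_type != "hubert":
        if bert_embedding:
            return None
        return LongTensor([speaker_id])
    return None  # assumption: hubert path not shown in the diff


def resolve_sid_after(model_type: str, bert_embedding: bool, speaker_id: int):
    # New behaviour: the speaker id tensor is built for every non-hubert
    # model, regardless of whether BERT prosody embeddings are in use.
    if model_type != "hubert":
        return LongTensor([speaker_id])
    return None  # assumption: hubert path not shown in the diff


if __name__ == "__main__":
    print(resolve_sid_before("vits", True, 3))  # None
    print(resolve_sid_after("vits", True, 3))   # tensor([3])
```

This also makes multi-speaker checkpoints usable together with BERT prosody embeddings, since the generator now receives a concrete `sid` instead of `None` in that configuration.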