Skip to content

Commit 8c96081

Browse files
committed
update
1 parent dd2d926 commit 8c96081

File tree

5 files changed: +13 -6 lines changed

5 files changed: +13 -6 lines changed

cosyvoice/bin/export_onnx.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
2828
sys.path.append('{}/../..'.format(ROOT_DIR))
2929
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30-
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
30+
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2, CosyVoice3
3131
from cosyvoice.utils.file_utils import logging
3232

3333

@@ -64,7 +64,10 @@ def main():
6464
try:
6565
model = CosyVoice2(args.model_dir)
6666
except Exception:
67-
raise TypeError('no valid model_type!')
67+
try:
68+
model = CosyVoice3(args.model_dir)
69+
except Exception:
70+
raise TypeError('no valid model_type!')
6871

6972
# 1. export flow decoder estimator
7073
estimator = model.model.flow.decoder.estimator

cosyvoice/cli/cosyvoice.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -221,7 +221,7 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, f
221221
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
222222
self.model.load('{}/llm.pt'.format(model_dir),
223223
'{}/flow.pt'.format(model_dir),
224-
'{}/bigvgan.pt'.format(model_dir))
224+
'{}/hift.pt'.format(model_dir))
225225
if load_vllm:
226226
self.model.load_vllm('{}/vllm'.format(model_dir))
227227
if load_jit:

cosyvoice/cli/model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -447,7 +447,7 @@ def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, u
447447
if speed != 1.0:
448448
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
449449
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
450-
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
450+
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel)
451451
if self.hift_cache_dict[uuid] is not None:
452452
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
453453
return tts_speech

cosyvoice/flow/DiT/dit.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,7 @@ def __init__(
115115
mu_dim=None,
116116
long_skip_connection=False,
117117
spk_dim=None,
118+
out_channels=None,
118119
static_chunk_size=50,
119120
num_decoding_left_chunks=2
120121
):
@@ -137,6 +138,7 @@ def __init__(
137138

138139
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
139140
self.proj_out = nn.Linear(dim, mel_dim)
141+
self.out_channels = out_channels
140142
self.static_chunk_size = static_chunk_size
141143
self.num_decoding_left_chunks = num_decoding_left_chunks
142144

cosyvoice/utils/class_utils.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,8 @@
3333
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
3434
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
3535
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
36-
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
37-
from cosyvoice.hifigan.generator import HiFTGenerator
36+
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
37+
from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
3838
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
3939

4040

@@ -80,4 +80,6 @@ def get_model_type(configs):
8080
return CosyVoiceModel
8181
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
8282
return CosyVoice2Model
83+
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
84+
return CosyVoice2Model
8385
raise TypeError('No valid model type found!')

0 commit comments

Comments
 (0)