This repository was archived by the owner on Nov 21, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 479
Expand file tree
/
Copy pathinference_v2.py
More file actions
144 lines (118 loc) · 5.59 KB
/
inference_v2.py
File metadata and controls
144 lines (118 loc) · 5.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import argparse
import torch
import yaml
import soundfile as sf
import time
from modules.commons import str2bool
# Set up device and torch configurations.
# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Half precision saves memory/bandwidth on accelerators, but CPU support for
# float16 kernels is limited and often slow, so use float32 on the CPU.
dtype = torch.float32 if device.type == "cpu" else torch.float16

# Global cache for the V2 wrapper: populated lazily by convert_voice_v2() on
# its first call and reused afterwards.
vc_wrapper_v2 = None
def load_v2_models(args):
    """Build and prepare the V2 voice-conversion wrapper.

    Per the original docstring, this mirrors the wrapper loading done in
    app.py. The wrapper is instantiated with Hydra from
    ``configs/v2/vc_wrapper.yaml``, which is resolved relative to the current
    working directory. It then loads the AR and CFM checkpoints named in
    *args*, moves to the module-level ``device``, switches to eval mode and
    sets up AR caches (``max_batch_size=1``, ``max_seq_len=4096``) in the
    module-level ``dtype``. When ``args.compile`` is truthy, TorchInductor
    options are tuned and the AR part is compiled.

    Args:
        args: Parsed CLI namespace. Reads ``ar_checkpoint_path``,
            ``cfm_checkpoint_path`` and ``compile``.

    Returns:
        The prepared wrapper instance.
    """
    from hydra.utils import instantiate
    from omegaconf import DictConfig

    # Read the config inside a context manager so the file handle is closed
    # promptly instead of being leaked.
    with open("configs/v2/vc_wrapper.yaml", "r", encoding="utf-8") as config_file:
        cfg = DictConfig(yaml.safe_load(config_file))
    vc_wrapper = instantiate(cfg)
    vc_wrapper.load_checkpoints(ar_checkpoint_path=args.ar_checkpoint_path,
                                cfm_checkpoint_path=args.cfm_checkpoint_path)
    vc_wrapper.to(device)
    vc_wrapper.eval()
    vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device)

    if args.compile:
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.triton.unique_kernel_names = True
        if hasattr(torch._inductor.config, "fx_graph_cache"):
            # Experimental feature to reduce compilation times, will be on by default in future
            torch._inductor.config.fx_graph_cache = True
        vc_wrapper.compile_ar()
        # CFM compilation is left disabled, as in the original script:
        # vc_wrapper.compile_cfm()
    return vc_wrapper
def convert_voice_v2(source_audio_path, target_audio_path, args):
    """Convert voice using the V2 model.

    The wrapper is loaded on first use via ``load_v2_models(args)`` and cached
    in the module-level ``vc_wrapper_v2``. Later calls reuse it, so
    checkpoint/compile options in later *args* are ignored.

    Args:
        source_audio_path: Path to the audio whose content is converted.
        target_audio_path: Path to the reference audio for the target voice.
        args: Parsed CLI namespace supplying the diffusion/sampling settings.

    Returns:
        The second element of the last pair yielded by the wrapper's
        streaming generator (``main`` unpacks it as ``(sample_rate, audio)``).
        Returns ``None`` if the generator yields nothing.
    """
    global vc_wrapper_v2
    if vc_wrapper_v2 is None:
        vc_wrapper_v2 = load_v2_models(args)

    # The wrapper exposes a streaming generator; drain it and keep only the
    # last "full audio" value it produced.
    stream = vc_wrapper_v2.convert_voice_with_streaming(
        source_audio_path=source_audio_path,
        target_audio_path=target_audio_path,
        diffusion_steps=args.diffusion_steps,
        length_adjust=args.length_adjust,
        # NOTE(review): "intelligebility" (sic) is the keyword the original
        # code passes; presumably it matches the wrapper's signature, so do
        # not "correct" it without checking the wrapper.
        intelligebility_cfg_rate=args.intelligibility_cfg_rate,
        similarity_cfg_rate=args.similarity_cfg_rate,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        convert_style=args.convert_style,
        anonymization_only=args.anonymization_only,
        device=device,
        dtype=dtype,
        stream_output=True,
    )

    # Pre-initialise so an empty stream returns None (which main() reports
    # as a failure) instead of raising UnboundLocalError.
    full_audio = None
    for _, full_audio in stream:
        pass
    return full_audio
def main(args):
    """Run one voice conversion from CLI arguments and save the result as WAV.

    The output is written to ``args.output``, which is created if missing, as
    ``vc_v2_<source>_<target>_<length_adjust>_<diffusion_steps>_<similarity_cfg_rate>.wav``.
    If conversion returns ``None``, an error is printed and nothing is written.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` argument parser).
    """
    # Create output directory if it doesn't exist
    os.makedirs(args.output, exist_ok=True)

    # On the first call this also includes model loading, since
    # convert_voice_v2() loads the wrapper lazily.
    start_time = time.perf_counter()
    result = convert_voice_v2(args.source, args.target, args)
    elapsed = time.perf_counter() - start_time

    if result is None:
        print("Error: Failed to convert voice")
        return

    # splitext removes only the final extension, so inner dots are kept
    # (e.g. "my.voice.wav" -> "my.voice").
    source_name = os.path.splitext(os.path.basename(args.source))[0]
    target_name = os.path.splitext(os.path.basename(args.target))[0]

    # Create a descriptive filename
    filename = f"vc_v2_{source_name}_{target_name}_{args.length_adjust}_{args.diffusion_steps}_{args.similarity_cfg_rate}.wav"
    output_path = os.path.join(args.output, filename)

    save_sr, audio = result
    sf.write(output_path, audio, save_sr)

    print(f"Voice conversion completed in {elapsed:.2f} seconds")
    print(f"Output saved to: {output_path}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Voice Conversion Inference Script")
    parser.add_argument("--source", type=str, required=True,
                        help="Path to source audio file")
    parser.add_argument("--target", type=str, required=True,
                        help="Path to target/reference audio file")
    parser.add_argument("--output", type=str, default="./output",
                        help="Output directory for converted audio")
    parser.add_argument("--diffusion-steps", type=int, default=30,
                        help="Number of diffusion steps")
    parser.add_argument("--length-adjust", type=float, default=1.0,
                        help="Length adjustment factor (<1.0 for speed-up, >1.0 for slow-down)")
    # str2bool rather than bool: bool("False") is True, so with type=bool
    # any value (including "False") would have enabled compilation.
    parser.add_argument("--compile", type=str2bool, default=False,
                        help="Whether to compile the model for faster inference")
    # V2 specific arguments
    parser.add_argument("--intelligibility-cfg-rate", type=float, default=0.7,
                        help="Intelligibility CFG rate for V2 model")
    parser.add_argument("--similarity-cfg-rate", type=float, default=0.7,
                        help="Similarity CFG rate for V2 model")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="Top-p sampling parameter for V2 model")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Temperature sampling parameter for V2 model")
    parser.add_argument("--repetition-penalty", type=float, default=1.0,
                        help="Repetition penalty for V2 model")
    parser.add_argument("--convert-style", type=str2bool, default=False,
                        help="Convert style/emotion/accent for V2 model")
    parser.add_argument("--anonymization-only", type=str2bool, default=False,
                        help="Anonymization only mode for V2 model")
    # V2 custom checkpoints (passed to the wrapper's load_checkpoints)
    parser.add_argument("--ar-checkpoint-path", type=str, default=None,
                        help="Path to custom AR checkpoint file")
    parser.add_argument("--cfm-checkpoint-path", type=str, default=None,
                        help="Path to custom CFM checkpoint file")
    args = parser.parse_args()
    main(args)