Update infer_gradio.py. Enable seed selection for multistyle generation

Author: SWivid
Date: 2025-05-05 00:58:24 +08:00
parent e6fee5e9ba
commit 818b868fab


@@ -129,7 +129,7 @@ def infer(
gen_text,
model,
remove_silence,
seed,
seed=None,
cross_fade_duration=0.15,
nfe_step=32,
speed=1,
@@ -140,7 +140,13 @@ def infer(
return gr.update(), gr.update(), ref_text
# Set inference seed
if seed is None:
seed = np.random.randint(0, 2**31 - 1)
elif seed < 0 or seed > 2**31 - 1:
gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
seed = np.random.randint(0, 2**31 - 1)
torch.manual_seed(seed)
used_seed = seed
if not gen_text.strip():
gr.Warning("Please enter text to generate or upload a text file.")
@@ -191,7 +197,7 @@ def infer(
spectrogram_path = tmp_spectrogram.name
save_spectrogram(combined_spectrogram, spectrogram_path)
return (final_sample_rate, final_wave), spectrogram_path, ref_text
return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
with gr.Blocks() as app_credits:
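For reference, the seed handling added to infer() reduces to the small pattern below. This is a standalone sketch, not code from the commit: the resolve_seed helper name is invented, and it assumes NumPy and PyTorch as already imported in infer_gradio.py.

import numpy as np
import torch

def resolve_seed(seed=None):
    # Hypothetical helper mirroring the logic above: None means "pick a random seed";
    # out-of-range values fall back to a random seed after a warning.
    if seed is None:
        seed = np.random.randint(0, 2**31 - 1)
    elif seed < 0 or seed > 2**31 - 1:
        print("Seed must be in range 0 ~ 2147483647. Using random seed instead.")
        seed = np.random.randint(0, 2**31 - 1)
    torch.manual_seed(seed)
    return seed  # surfaced to the UI as used_seed so a good result can be reproduced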
@@ -277,27 +283,21 @@ with gr.Blocks() as app_tts:
nfe_slider,
speed_slider,
):
# Determine the seed to use
if randomize_seed:
seed = np.random.randint(0, 2**31 - 1)
else:
seed = seed_input
if seed < 0 or seed > 2**31 - 1:
gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
seed = np.random.randint(0, 2**31 - 1)
seed_input = None
audio_out, spectrogram_path, ref_text_out = infer(
audio_out, spectrogram_path, ref_text_out, used_seed = infer(
ref_audio_input,
ref_text_input,
gen_text_input,
tts_model_choice,
remove_silence,
seed=seed,
seed=seed_input,
cross_fade_duration=cross_fade_duration_slider,
nfe_step=nfe_slider,
speed=speed_slider,
)
return audio_out, spectrogram_path, ref_text_out, seed
return audio_out, spectrogram_path, ref_text_out, used_seed
gen_text_file.upload(
load_text_from_file,
@@ -329,26 +329,34 @@ with gr.Blocks() as app_tts:
def parse_speechtypes_text(gen_text):
# Pattern to find {speechtype}
pattern = r"\{(.*?)\}"
# Pattern to find {str} or {"name": str, "seed": int, "speed": float}
pattern = r"(\{.*?\})"
# Split the text by the pattern
tokens = re.split(pattern, gen_text)
segments = []
current_style = "Regular"
current_type_dict = {
"name": "Regular",
"seed": -1,
"speed": 1.0,
}
for i in range(len(tokens)):
if i % 2 == 0:
# This is text
text = tokens[i].strip()
if text:
segments.append({"style": current_style, "text": text})
current_type_dict["text"] = text
segments.append(current_type_dict)
else:
# This is style
style = tokens[i].strip()
current_style = style
# This is type
type_str = tokens[i].strip()
try: # if type dict
current_type_dict = json.loads(type_str)
except json.decoder.JSONDecodeError:
current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
return segments
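To make the new parsing concrete: a marker can now be either a bare {Name} tag or a JSON object carrying a per-segment seed and speed. Assuming the logic above, a made-up input like the following would be split into one dict per text chunk (expected result shown as a comment, illustrative only):

sample = (
    '{Regular} Hello there. '
    '{"name": "Whisper", "seed": 42, "speed": 0.8} I have a secret.'
)
# parse_speechtypes_text(sample) should yield:
# [
#     {"name": "Regular", "seed": -1, "speed": 1.0, "text": "Hello there."},
#     {"name": "Whisper", "seed": 42, "speed": 0.8, "text": "I have a secret."},
# ]
# Bare tags fall back to seed=-1 (random) and speed=1.0 via the JSONDecodeError branch.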
@@ -366,41 +374,48 @@ with gr.Blocks() as app_multistyle:
with gr.Row():
gr.Markdown(
"""
**Example Input:**
{Regular} Hello, I'd like to order a sandwich please.
{Surprised} What do you mean you're out of bread?
{Sad} I really wanted a sandwich though...
{Angry} You know what, darn you and your little shop!
{Whisper} I'll just go back home and cry now.
**Example Input:** <br>
{Regular} Hello, I'd like to order a sandwich please. <br>
{Surprised} What do you mean you're out of bread? <br>
{Sad} I really wanted a sandwich though... <br>
{Angry} You know what, darn you and your little shop! <br>
{Whisper} I'll just go back home and cry now. <br>
{Shouting} Why me?!
"""
)
gr.Markdown(
"""
**Example Input 2:**
{Speaker1_Happy} Hello, I'd like to order a sandwich please.
{Speaker2_Regular} Sorry, we're out of bread.
{Speaker1_Sad} I really wanted a sandwich though...
{Speaker2_Whisper} I'll give you the last one I was hiding.
**Example Input 2:** <br>
{"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please. <br>
{"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread. <br>
{"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though... <br>
{"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
"""
)
gr.Markdown(
"Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
)
# Regular speech type (mandatory)
with gr.Row() as regular_row:
with gr.Row(variant="compact") as regular_row:
with gr.Column(scale=1, min_width=160):
regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
regular_insert = gr.Button("Insert Label", variant="secondary")
with gr.Column(scale=3):
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
with gr.Column(scale=3):
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=8, scale=3)
with gr.Column(scale=1):
regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
with gr.Row():
regular_seed_slider = gr.Slider(
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
)
regular_speed_slider = gr.Slider(
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
)
with gr.Column(scale=1, min_width=160):
regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
# Regular speech type (max 100)
max_speech_types = 100
@@ -409,32 +424,54 @@ with gr.Blocks() as app_multistyle:
speech_type_audios = [regular_audio]
speech_type_ref_texts = [regular_ref_text]
speech_type_ref_text_files = [regular_ref_text_file]
speech_type_seeds = [regular_seed_slider]
speech_type_speeds = [regular_speed_slider]
speech_type_delete_btns = [None]
speech_type_insert_btns = [regular_insert]
# Additional speech types (99 more)
for i in range(max_speech_types - 1):
with gr.Row(visible=False) as row:
with gr.Row(variant="compact", visible=False) as row:
with gr.Column(scale=1, min_width=160):
name_input = gr.Textbox(label="Speech Type Name")
delete_btn = gr.Button("Delete Type", variant="secondary")
insert_btn = gr.Button("Insert Label", variant="secondary")
delete_btn = gr.Button("Delete Type", variant="stop")
with gr.Column(scale=3):
audio_input = gr.Audio(label="Reference Audio", type="filepath")
with gr.Column(scale=3):
ref_text_input = gr.Textbox(label="Reference Text", lines=8, scale=3)
with gr.Column(scale=1):
ref_text_file_input = gr.File(
label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
)
ref_text_input = gr.Textbox(label="Reference Text", lines=4)
with gr.Row():
seed_input = gr.Slider(
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
)
speed_input = gr.Slider(
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
)
with gr.Column(scale=1, min_width=160):
ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
speech_type_rows.append(row)
speech_type_names.append(name_input)
speech_type_audios.append(audio_input)
speech_type_ref_texts.append(ref_text_input)
speech_type_ref_text_files.append(ref_text_file_input)
speech_type_seeds.append(seed_input)
speech_type_speeds.append(speed_input)
speech_type_delete_btns.append(delete_btn)
speech_type_insert_btns.append(insert_btn)
# Global logic for all speech types
for i in range(max_speech_types):
speech_type_audios[i].clear(
lambda: [None, None],
None,
[speech_type_ref_texts[i], speech_type_ref_text_files[i]],
)
speech_type_ref_text_files[i].upload(
load_text_from_file,
inputs=[speech_type_ref_text_files[i]],
outputs=[speech_type_ref_texts[i]],
)
# Button to add speech type
add_speech_type_btn = gr.Button("Add Speech Type")
@@ -470,18 +507,6 @@ with gr.Blocks() as app_multistyle:
speech_type_ref_text_files[i],
],
)
speech_type_ref_text_files[i].upload(
load_text_from_file,
inputs=[speech_type_ref_text_files[i]],
outputs=[speech_type_ref_texts[i]],
)
# Update regular speech type ref text file
regular_ref_text_file.upload(
load_text_from_file,
inputs=[regular_ref_text_file],
outputs=[regular_ref_text],
)
# Text input for the prompt
with gr.Row():
@@ -495,10 +520,17 @@ with gr.Blocks() as app_multistyle:
gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
def make_insert_speech_type_fn(index):
def insert_speech_type_fn(current_text, speech_type_name):
def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
current_text = current_text or ""
speech_type_name = speech_type_name or "None"
updated_text = current_text + f"{{{speech_type_name}}} "
if not speech_type_name:
gr.Warning("Please enter speech type name before insert.")
return current_text
speech_type_dict = {
"name": speech_type_name,
"seed": speech_type_seed,
"speed": speech_type_speed,
}
updated_text = current_text + json.dumps(speech_type_dict) + " "
return updated_text
return insert_speech_type_fn
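Concretely, the Insert Label button now appends a JSON marker instead of the old bare {Name} tag, so the seed and speed slider values travel with the label. A rough illustration of the string it builds (the values here are made up), which round-trips through parse_speechtypes_text() above:

import json

speech_type_dict = {"name": "Speaker1_Happy", "seed": 42, "speed": 1.0}
print(json.dumps(speech_type_dict) + " ")
# -> {"name": "Speaker1_Happy", "seed": 42, "speed": 1.0}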
@@ -507,16 +539,24 @@ with gr.Blocks() as app_multistyle:
insert_fn = make_insert_speech_type_fn(i)
insert_btn.click(
insert_fn,
inputs=[gen_text_input_multistyle, speech_type_names[i]],
inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
outputs=gen_text_input_multistyle,
)
with gr.Accordion("Advanced Settings", open=False):
remove_silence_multistyle = gr.Checkbox(
label="Remove Silences",
info="Turn on to automatically detect and crop long silences.",
value=True,
)
with gr.Accordion("Advanced Settings", open=True):
with gr.Row():
with gr.Column():
show_cherrypick_multistyle = gr.Checkbox(
label="Show Cherry-pick Interface",
info="Turn on to show interface, picking seeds from previous generations.",
value=False,
)
with gr.Column():
remove_silence_multistyle = gr.Checkbox(
label="Remove Silences",
info="Turn on to automatically detect and crop long silences.",
value=True,
)
# Generate button
generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
@@ -524,6 +564,24 @@ with gr.Blocks() as app_multistyle:
# Output audio
audio_output_multistyle = gr.Audio(label="Synthesized Audio")
# Cherry-pick textbox listing the seed used for each generated segment
cherrypick_interface_multistyle = gr.Textbox(
label="Cherry-pick Interface",
lines=10,
max_lines=40,
show_copy_button=True,
interactive=False,
visible=False,
)
# Show or hide the cherry-pick interface when the checkbox is toggled
show_cherrypick_multistyle.change(
lambda is_visible: gr.update(visible=is_visible),
show_cherrypick_multistyle,
cherrypick_interface_multistyle,
)
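The show/hide wiring uses the usual Gradio pattern of returning gr.update(visible=...) from a change callback. A minimal self-contained sketch of that pattern, with placeholder component names rather than the ones in the commit:

import gradio as gr

with gr.Blocks() as demo:
    show_box = gr.Checkbox(label="Show details", value=False)
    details = gr.Textbox(label="Details", visible=False, interactive=False)
    # Toggle the textbox's visibility whenever the checkbox changes.
    show_box.change(lambda visible: gr.update(visible=visible), show_box, details)

# demo.launch()  # uncomment to try it locally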
# Function to load text to generate from file
gen_text_file_multistyle.upload(
load_text_from_file,
inputs=[gen_text_file_multistyle],
@@ -557,44 +615,60 @@ with gr.Blocks() as app_multistyle:
# For each segment, generate speech
generated_audio_segments = []
current_style = "Regular"
current_type_name = "Regular"
inference_meta_data = ""
for segment in segments:
style = segment["style"]
name = segment["name"]
seed = segment["seed"]
speed = segment["speed"]
text = segment["text"]
if style in speech_types:
current_style = style
if name in speech_types:
current_type_name = name
else:
gr.Warning(f"Type {style} is not available, will use Regular as default.")
current_style = "Regular"
gr.Warning(f"Type {name} is not available, will use Regular as default.")
current_type_name = "Regular"
try:
ref_audio = speech_types[current_style]["audio"]
ref_audio = speech_types[current_type_name]["audio"]
except KeyError:
gr.Warning(f"Please provide reference audio for type {current_style}.")
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
ref_text = speech_types[current_style].get("ref_text", "")
gr.Warning(f"Please provide reference audio for type {current_type_name}.")
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
ref_text = speech_types[current_type_name].get("ref_text", "")
# TODO. Attribute each type a unique seed (maybe also speed, pseudo-feature for #730 #813)
seed = np.random.randint(0, 2**31 - 1)
if seed == -1:
seed = None
# Generate speech for this segment
audio_out, _, ref_text_out = infer(
ref_audio, ref_text, text, tts_model_choice, remove_silence, seed, 0, show_info=print
audio_out, _, ref_text_out, used_seed = infer(
ref_audio,
ref_text,
text,
tts_model_choice,
remove_silence,
seed=seed,
cross_fade_duration=0,
speed=speed,
show_info=print,
) # show_info=print keeps Gradio from scrolling to the top while generating
sr, audio_data = audio_out
generated_audio_segments.append(audio_data)
speech_types[current_style]["ref_text"] = ref_text_out
speech_types[current_type_name]["ref_text"] = ref_text_out
inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
# Concatenate all audio segments
if generated_audio_segments:
final_audio_data = np.concatenate(generated_audio_segments)
return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
return (
[(sr, final_audio_data)]
+ [speech_types[name]["ref_text"] for name in speech_types]
+ [inference_meta_data]
)
else:
gr.Warning("No audio generated.")
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
generate_multistyle_btn.click(
generate_multistyle_speech,
@@ -607,7 +681,7 @@ with gr.Blocks() as app_multistyle:
+ [
remove_silence_multistyle,
],
outputs=[audio_output_multistyle] + speech_type_ref_texts,
outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
)
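Each generated segment is also logged to the cherry-pick textbox as a JSON marker followed by the segment text, i.e. the same syntax parse_speechtypes_text() accepts, so a line whose result you like can be pasted back into the prompt to pin its seed. A rough illustration of one such line (the seed value is invented):

import json

name, used_seed, speed, text = "Whisper", 123456789, 0.8, "I have a secret."
meta_line = json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
print(meta_line, end="")
# -> {"name": "Whisper", "seed": 123456789, "speed": 0.8} I have a secret.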
# Validation function to disable Generate button if speech types are missing
@@ -624,7 +698,7 @@ with gr.Blocks() as app_multistyle:
# Parse the gen_text to get the speech types used
segments = parse_speechtypes_text(gen_text)
speech_types_in_text = set(segment["style"] for segment in segments)
speech_types_in_text = set(segment["name"] for segment in segments)
# Check if all speech types in text are available
missing_speech_types = speech_types_in_text - speech_types_available
@@ -788,27 +862,21 @@ Have a conversation with an AI using your reference voice!
if not last_ai_response or conv_state[-1]["role"] != "assistant":
return None, ref_text, seed_input
# Determine the seed to use
if randomize_seed:
seed = np.random.randint(0, 2**31 - 1)
else:
seed = seed_input
if seed < 0 or seed > 2**31 - 1:
gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
seed = np.random.randint(0, 2**31 - 1)
seed_input = None
audio_result, _, ref_text_out = infer(
audio_result, _, ref_text_out, used_seed = infer(
ref_audio,
ref_text,
last_ai_response,
tts_model_choice,
remove_silence,
seed=seed,
seed=seed_input,
cross_fade_duration=0.15,
speed=1.0,
show_info=print, # show_info=print keeps Gradio from scrolling to the top while generating
)
return audio_result, ref_text_out, seed
return audio_result, ref_text_out, used_seed
def clear_conversation():
"""Reset the conversation"""