reorganize and distinguish behavior from local and space
@@ -1,6 +1,7 @@
 # ruff: noqa: E402
 # Above allows ruff to ignore E402: module level import not at top of file
 
+import gc
 import json
 import re
 import tempfile
@@ -11,6 +12,7 @@ import click
 import gradio as gr
 import numpy as np
 import soundfile as sf
+import torch
 import torchaudio
 from cached_path import cached_path
 from transformers import AutoModelForCausalLM, AutoTokenizer
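The two import hunks move `gc` and `torch` to module level; in the old code they were imported inside the model-loading functions right before the cleanup calls. That cleanup is the usual pattern for releasing a transformers model from GPU memory before loading a different one. A minimal standalone sketch of the pattern (the `reload_chat_model` helper and module-level names here are illustrative, not the app's own API):

import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model, tokenizer = None, None  # module-level state, analogous to chat_model_state / chat_tokenizer_state

def reload_chat_model(model_name: str):
    """Drop any previously loaded model, then load `model_name` fresh."""
    global model, tokenizer
    if model is not None:
        model, tokenizer = None, None  # drop the last Python references
        gc.collect()                   # let the garbage collector reclaim the objects
        torch.cuda.empty_cache()       # return the freed CUDA blocks to the allocator
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer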
@@ -550,82 +552,47 @@ Have a conversation with an AI using your reference voice!
 """
     )
 
-    if not USING_SPACES:
-        model_name_input = gr.Textbox(
-            label="Chat Model Name",
-            value="Qwen/Qwen2.5-3B-Instruct",
-            info="Enter the name of a HuggingFace chat model",
-        )
-        load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
-
-        chat_interface_container = gr.Column(visible=False)
-
-        @gpu_decorator
-        def load_chat_model(model_name):
-            global chat_model_state, chat_tokenizer_state
-            # Always reload model when button is clicked
-            if chat_model_state is not None:
-                print(f"Unloading previous model and loading new model: {model_name}")
-                # Clear previous model from memory
-                chat_model_state = None
-                chat_tokenizer_state = None
-                import gc
-                import torch
-
-                gc.collect()
-                torch.cuda.empty_cache()
-
-            show_info = gr.Info
-            show_info("Loading chat model...")
-            print(f"Loading chat model: {model_name}")
-            chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
-            chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
-            print(f"Model {model_name} loaded successfully!")
-            show_info("Chat model loaded.")
-
-            return gr.update(visible=False), gr.update(visible=True)
-
-        load_chat_model_btn.click(
-            load_chat_model, inputs=[model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
-        )
-
-    else:
-        model_name_input = gr.Textbox(
-            label="Chat Model Name",
-            value="Qwen/Qwen2.5-3B-Instruct",
-            info="Enter the name of a HuggingFace chat model",
-        )
-        chat_interface_container = gr.Column()
-
-        @gpu_decorator
-        def load_spaces_model(model_name):
-            global chat_model_state, chat_tokenizer_state
-            # Always reload model when called
-            if chat_model_state is not None:
-                print(f"Unloading previous Spaces model and loading new model: {model_name}")
-                # Clear previous model from memory
-                chat_model_state = None
-                chat_tokenizer_state = None
-                import gc
-                import torch
-
-                gc.collect()
-                torch.cuda.empty_cache()
-
-            print(f"Loading chat model in Spaces: {model_name}")
-            chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
-            chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
-            print(f"Model {model_name} loaded successfully in Spaces!")
-            return True
-
-        # Load model when model name is changed
-        model_name_input.change(
-            load_spaces_model,
-            inputs=[model_name_input],
-            outputs=[],
-        )
-        # Initialize with default model
-        load_spaces_model("Qwen/Qwen2.5-3B-Instruct")
+    chat_model_name_list = ["Qwen/Qwen2.5-3B-Instruct",]
+
+    @gpu_decorator
+    def load_chat_model(chat_model_name):
+        show_info = gr.Info
+        global chat_model_state, chat_tokenizer_state
+        if chat_model_state is not None:
+            chat_model_state = None
+            chat_tokenizer_state = None
+            gc.collect()
+            torch.cuda.empty_cache()
+
+        show_info(f"Loading chat model: {chat_model_name}")
+        chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto")
+        chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
+        show_info(f"Chat model {chat_model_name} loaded successfully!")
+
+        return gr.update(visible=False), gr.update(visible=True)
+
+    if USING_SPACES:
+        load_chat_model(chat_model_name_list[0])
+
+    chat_model_name_input = gr.Dropdown(
+        choices=chat_model_name_list,
+        value=chat_model_name_list[0],
+        label="Chat Model Name",
+        info="Enter the name of a HuggingFace chat model",
+        allow_custom_value=not USING_SPACES,
+    )
+    load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES)
+    chat_interface_container = gr.Column(visible=USING_SPACES)
+
+    chat_model_name_input.change(
+        lambda: gr.update(visible=True),
+        None,
+        load_chat_model_btn,
+        show_progress="hidden",
+    )
+    load_chat_model_btn.click(
+        load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
+    )
 
     with chat_interface_container:
        with gr.Row():
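Net effect of the hunk: the old code kept two separate branches (a button-driven loader locally, a `.change`-driven `load_spaces_model` on Spaces), while the new code has a single `load_chat_model` and expresses the local/Spaces difference purely through widget visibility plus an eager load on Spaces. A condensed sketch of that gating, assuming a boolean `USING_SPACES` flag that the real file computes elsewhere (stubbed here for illustration):

import gradio as gr

USING_SPACES = False  # assumption: the real flag is set earlier in infer_gradio.py

def load_chat_model(chat_model_name):
    # stand-in for the real loader; hides the load button and reveals the chat UI
    print(f"loading {chat_model_name} ...")
    return gr.update(visible=False), gr.update(visible=True)

with gr.Blocks() as demo:
    names = ["Qwen/Qwen2.5-3B-Instruct"]
    if USING_SPACES:
        load_chat_model(names[0])  # Spaces: preload the default model at startup

    name_input = gr.Dropdown(choices=names, value=names[0], label="Chat Model Name",
                             allow_custom_value=not USING_SPACES)  # free-text names only locally
    load_btn = gr.Button("Load Chat Model", visible=not USING_SPACES)  # hidden on Spaces
    chat_area = gr.Column(visible=USING_SPACES)  # chat UI visible immediately on Spaces

    # locally: picking a new name re-shows the button; clicking it loads the model and opens the chat UI
    name_input.change(lambda: gr.update(visible=True), None, load_btn)
    load_btn.click(load_chat_model, inputs=[name_input], outputs=[load_btn, chat_area])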