Reapply pre-commit hooks

huanglizhuo committed 2025-03-29 20:58:42 +09:00
parent f34465d118
commit eaa7fd8a01


@@ -551,34 +551,81 @@ Have a conversation with an AI using your reference voice!
     )
 
     if not USING_SPACES:
+        model_name_input = gr.Textbox(
+            label="Chat Model Name",
+            value="Qwen/Qwen2.5-3B-Instruct",
+            info="Enter the name of a HuggingFace chat model",
+        )
         load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
 
         chat_interface_container = gr.Column(visible=False)
 
         @gpu_decorator
-        def load_chat_model():
+        def load_chat_model(model_name):
             global chat_model_state, chat_tokenizer_state
-            if chat_model_state is None:
-                show_info = gr.Info
-                show_info("Loading chat model...")
-                model_name = "Qwen/Qwen2.5-3B-Instruct"
-                chat_model_state = AutoModelForCausalLM.from_pretrained(
-                    model_name, torch_dtype="auto", device_map="auto"
-                )
-                chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
-                show_info("Chat model loaded.")
+            # Always reload model when button is clicked
+            if chat_model_state is not None:
+                print(f"Unloading previous model and loading new model: {model_name}")
+                # Clear previous model from memory
+                chat_model_state = None
+                chat_tokenizer_state = None
+                import gc
+                import torch
+
+                gc.collect()
+                torch.cuda.empty_cache()
+
+            show_info = gr.Info
+            show_info("Loading chat model...")
+            print(f"Loading chat model: {model_name}")
+            chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
+            chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
+            print(f"Model {model_name} loaded successfully!")
+            show_info("Chat model loaded.")
 
             return gr.update(visible=False), gr.update(visible=True)
 
-        load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
+        load_chat_model_btn.click(
+            load_chat_model, inputs=[model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
+        )
 
     else:
+        model_name_input = gr.Textbox(
+            label="Chat Model Name",
+            value="Qwen/Qwen2.5-3B-Instruct",
+            info="Enter the name of a HuggingFace chat model",
+        )
         chat_interface_container = gr.Column()
 
-        if chat_model_state is None:
-            model_name = "Qwen/Qwen2.5-3B-Instruct"
+        @gpu_decorator
+        def load_spaces_model(model_name):
+            global chat_model_state, chat_tokenizer_state
+            # Always reload model when called
+            if chat_model_state is not None:
+                print(f"Unloading previous Spaces model and loading new model: {model_name}")
+                # Clear previous model from memory
+                chat_model_state = None
+                chat_tokenizer_state = None
+                import gc
+                import torch
+
+                gc.collect()
+                torch.cuda.empty_cache()
+
+            print(f"Loading chat model in Spaces: {model_name}")
+            chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
+            chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
+            print(f"Model {model_name} loaded successfully in Spaces!")
+            return True
+
+        # Load model when model name is changed
+        model_name_input.change(
+            load_spaces_model,
+            inputs=[model_name_input],
+            outputs=[],
+        )
+
+        # Initialize with default model
+        load_spaces_model("Qwen/Qwen2.5-3B-Instruct")
 
     with chat_interface_container:
         with gr.Row():
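
Both branches of the diff now perform the same unload-then-reload dance on the `chat_model_state` / `chat_tokenizer_state` globals. Below is a minimal standalone sketch of that pattern, not code from this commit: it assumes `torch`, `transformers`, and (for `device_map="auto"`) `accelerate` are installed, and the helper name `swap_chat_model` is hypothetical.

```python
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Module-level state, mirroring chat_model_state / chat_tokenizer_state
# in the Gradio app.
chat_model_state = None
chat_tokenizer_state = None


def swap_chat_model(model_name: str):
    """Unload whatever model is currently resident, then load `model_name` fresh."""
    global chat_model_state, chat_tokenizer_state

    if chat_model_state is not None:
        # Drop the only references to the old weights so they become
        # collectable, then flush PyTorch's CUDA caching allocator.
        chat_model_state = None
        chat_tokenizer_state = None
        gc.collect()
        torch.cuda.empty_cache()

    chat_model_state = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )
    chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
    return chat_model_state, chat_tokenizer_state
```

Setting the globals to `None` drops the last references to the old weights; `gc.collect()` then forces collection of any lingering reference cycles, and `torch.cuda.empty_cache()` hands the cached blocks back to the device, which matters when the next checkpoint is larger or other processes share the GPU.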