Added intel XPU support

This commit is contained in:
98440
2025-01-20 00:47:57 +08:00
parent 9e51878d18
commit 81ce1d8670
7 changed files with 31 additions and 6 deletions

View File

@@ -32,6 +32,10 @@ pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https:/
# AMD GPU: install pytorch with your ROCm version, e.g.
pip install torch==2.5.1+rocm6.2 torchaudio==2.5.1+rocm6.2 --extra-index-url https://download.pytorch.org/whl/rocm6.2
# intel GPU: install pytorch with your XPU version, e.g.
# Intel® Deep Learning Essentials or Intel® oneAPI Base Toolkit must be installed
pip install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu
```
Then you can choose from a few options below:

View File

@@ -47,7 +47,7 @@ class F5TTS:
else:
import torch
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
self.device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Load models
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)

View File

@@ -13,7 +13,7 @@ def main():
parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
predictor = predictor.to(device)

View File

@@ -10,7 +10,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectro
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# --------------------- Dataset Settings -------------------- #

View File

@@ -33,7 +33,7 @@ from f5_tts.model.utils import (
_ref_audio_cache = {}
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# -----------------------------------------

View File

@@ -17,7 +17,7 @@ from model.backbones.dit import DiT
class TTSStreamingProcessor:
def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
self.device = device or (
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
"cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)
# Load the model using the provided checkpoint and vocab files

View File

@@ -46,7 +46,7 @@ path_data = str(files("f5_tts").joinpath("../../data"))
path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Save settings from a JSON file
@@ -889,6 +889,13 @@ def calculate_train(
gpu_properties = torch.cuda.get_device_properties(i)
total_memory += gpu_properties.total_memory / (1024**3) # in GB
elif torch.xpu.is_available():
gpu_count = torch.xpu.device_count()
total_memory = 0
for i in range(gpu_count):
gpu_properties = torch.xpu.get_device_properties(i)
total_memory += gpu_properties.total_memory / (1024**3)
elif torch.backends.mps.is_available():
gpu_count = 1
total_memory = psutil.virtual_memory().available / (1024**3)
@@ -1284,7 +1291,21 @@ def get_gpu_stats():
f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
)
elif torch.xpu.is_available():
gpu_count = torch.xpu.device_count()
for i in range(gpu_count):
gpu_name = torch.xpu.get_device_name(i)
gpu_properties = torch.xpu.get_device_properties(i)
total_memory = gpu_properties.total_memory / (1024**3) # in GB
allocated_memory = torch.xpu.memory_allocated(i) / (1024**2) # in MB
reserved_memory = torch.xpu.memory_reserved(i) / (1024**2) # in MB
gpu_stats += (
f"GPU {i} Name: {gpu_name}\n"
f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
)
elif torch.backends.mps.is_available():
gpu_count = 1
gpu_stats += "MPS GPU\n"