Optimize inference (#990)
* [feature] add dataset class
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [dev] combine agent and tts infer
* [feature] update inference
* [feature] update uv.lock
* [Merge] merge upstream/main
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [fix] remove unused files
* [fix] remove unused files
* [fix] remove unused files
* [fix] fix infer bugs

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3d31a80ad1
commit 4bf24d8c33
@@ -976,49 +976,3 @@ class DAC(BaseModel, CodecMixin):
         z = vq_results[0] if isinstance(vq_results, tuple) else vq_results.z
         x = self.decode(z)
         return x[..., :length], vq_results
-
-
-if __name__ == "__main__":
-
-    def filter_state_dict_shapes(params, model):
-        model_state_dict = model.state_dict()
-        filtered_state_dict = {
-            k: v
-            for k, v in params.items()
-            if k in model_state_dict and v.shape == model_state_dict[k].shape
-        }
-        skipped_keys = set(params.keys()) - set(filtered_state_dict.keys())
-        if skipped_keys:
-            print(
-                f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
-            )
-        return filtered_state_dict, skipped_keys
-
-    model = hydra.utils.instantiate(
-        OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml")
-    )
-    sd = torch.load("checkpoints/openaudio-s1-mini/firefly-gan-large.pth")
-    filtered_sd, skipped_keys = filter_state_dict_shapes(sd, model)
-    print(f"Skipped keys: {skipped_keys}")
-    model.load_state_dict(filtered_sd, strict=False)
-    model.eval()
-
-    src_audio_path = "./test.wav"
-    wave_np, _ = librosa.load(src_audio_path, sr=44100, mono=False)
-    if len(wave_np.shape) == 1:
-        wave_np = wave_np[None, :]
-    wave_tensor = torch.from_numpy(wave_np).unsqueeze(1)
-
-    with torch.no_grad():
-        # encode returns (indices, indices_lens)
-        indices, indices_lens = model.encode(wave_tensor)
-        print(f"Indices shape: {indices.shape}")
-        print(f"Indices lengths: {indices_lens}")
-
-        # decode takes two arguments: indices and feature_lengths
-        fake_audio, audio_lengths = model.decode(indices, indices_lens)
-        print(f"Decoded audio shape: {fake_audio.shape}")
-        print(f"Audio lengths: {audio_lengths}")
-
-        # save the reconstructed audio
-        sf.write("fake.wav", fake_audio.squeeze(1).cpu().numpy().T, 44100)
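Aside: the removed self-test's filter_state_dict_shapes pattern is generally useful whenever a checkpoint and a resized model disagree on some tensor shapes. A minimal self-contained sketch of the same idea (toy modules, not part of this diff):

    from torch import nn

    def filter_by_shape(params, model):
        # keep only checkpoint entries whose key exists in the target model
        # and whose tensor shape matches exactly
        target = model.state_dict()
        kept = {
            k: v for k, v in params.items()
            if k in target and v.shape == target[k].shape
        }
        return kept, set(params) - set(kept)

    src = nn.Linear(4, 4)   # stand-in for the checkpoint's source module
    dst = nn.Linear(8, 4)   # resized module: weight is (4, 8), not (4, 4)
    kept, skipped = filter_by_shape(src.state_dict(), dst)
    print(skipped)          # {'weight'} -- the (4,) bias still matches
    dst.load_state_dict(kept, strict=False)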
@@ -10,7 +10,6 @@ from typing import Literal, Optional, Tuple, Union

 import click
 import numpy as np
 import torch
 import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
 from tqdm import tqdm
@@ -21,9 +20,8 @@ from fish_speech.content_sequence import (
     TextPart,
     VQPart,
 )
-from fish_speech.models.text2semantic.llama import BaseModelArgs
-from fish_speech.text import clean_text, split_text
-from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
+from fish_speech.text import split_text
+from fish_speech.tokenizer import IM_END_TOKEN

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -37,7 +35,6 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):

 from torch.nn.attention import SDPBackend, sdpa_kernel

 from fish_speech.models.text2semantic.llama import (
     BaseTransformer,
     DualARTransformer,
     NaiveTransformer,
 )
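For context, the Inductor flags at the top of this file follow the usual guarded-config pattern: coordinate_descent_tuning is set unconditionally, while fx_graph_cache only exists on newer torch builds, hence the hasattr check in the hunk header above. The guarded body is not visible in this extract; enabling the cache is the assumed intent:

    import torch
    import torch._inductor.config

    torch._inductor.config.coordinate_descent_tuning = True
    if hasattr(torch._inductor.config, "fx_graph_cache"):
        torch._inductor.config.fx_graph_cache = True  # assumed guarded body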
@@ -98,16 +95,27 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
     semantic_ids: list,
     previous_tokens: torch.Tensor = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
+    """
+    Generate one token using dual autoregressive transformer for text-to-speech.
+
+    First generates semantic tokens, then generates acoustic codebook tokens sequentially.
+
+    Args:
+        x: Input token tensor (1, num_codebooks+1, seq_len)
+        input_pos: Position indices for input tokens (seq_len,)
+        temperature/top_p/repetition_penalty: Sampling parameters (1, 1)
+        previous_tokens: Previous tokens for repetition penalty (1, num_codebooks+1, history_seq_len)
+        audio_masks/audio_parts: Audio conditioning tensors (num_codebooks, seq_len)
+
+    Returns:
+        Generated tokens tensor (num_codebooks+1, 1) - one token per codebook
+    """
     x = model.forward_generate(x, input_pos)

     sampling_kwargs_main = sampling_kwargs.copy()
     # sampling_kwargs_main["temperature"] = 0.1
     # sampling_kwargs_main["top_p"] = 0.1
     # sampling_kwargs_main["repetition_penalty"] = 1.0

     codebooks = [
         sample(
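The new docstring pins down the tensor layout, which is easy to illustrate in isolation. A toy sketch of the documented shapes (random logits, not the repo's model):

    import torch

    # row 0 is the semantic token stream, rows 1..num_codebooks are acoustic codes
    num_codebooks, vocab = 8, 32
    x = torch.randint(0, vocab, (1, num_codebooks + 1, 16))  # (1, C+1, seq_len)

    # one decode step samples one token per codebook -> shape (C+1, 1)
    probs = torch.softmax(torch.randn(num_codebooks + 1, vocab), dim=-1)
    step = torch.multinomial(probs, num_samples=1)
    assert step.shape == (num_codebooks + 1, 1)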
@@ -152,12 +160,7 @@ def decode_one_token_ar(
         codebooks.append(a)

     codebooks = torch.stack(codebooks, dim=0)
-    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    # codebooks[1:, :] = torch.masked_fill(
-    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
-    # )

-    # print(codebooks)
     return codebooks
@@ -166,10 +169,24 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
     semantic_ids: list,
     decode_one_token=decode_one_token_ar,
     **sampling_kwargs,
 ):
+    """
+    Generate n tokens iteratively using the model.
+
+    Args:
+        model: The transformer model
+        cur_token: Current token tensor of shape (1, num_codebooks+1, seq_len)
+        input_pos: Current input position tensor
+        num_new_tokens: Number of new tokens to generate
+        semantic_ids: List of semantic token IDs
+        decode_one_token: Function to decode one token
+        **sampling_kwargs: Additional sampling parameters
+
+    Returns:
+        Generated tokens tensor of shape (num_codebooks+1, generated_len)
+    """
     previous_tokens = torch.zeros(
         (model.config.num_codebooks + 1, model.config.max_seq_len),
         dtype=torch.int,
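The loop's history bookkeeping is worth seeing on its own: a buffer pre-allocated at max_seq_len, sliced into a sliding window for the repetition penalty (the else-branch slicing appears in the next hunk; the short-history branch here is my assumption):

    import torch

    num_codebooks, max_seq_len, win_size = 8, 2048, 16
    previous_tokens = torch.zeros(
        (num_codebooks + 1, max_seq_len), dtype=torch.int
    )

    def history_window(i):
        lo = max(i - win_size, 0)
        return previous_tokens[:, lo:i]

    assert history_window(5).shape == (9, 5)     # short history: whole prefix
    assert history_window(100).shape == (9, 16)  # steady state: last win_size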
@@ -184,21 +201,14 @@ def decode_n_tokens(
         else:
             window = previous_tokens[:, i - win_size : i]

-        with (
-            torch.backends.cuda.sdp_kernel(
-                enable_flash=False, enable_mem_efficient=False, enable_math=True
-            )
-            if torch.cuda.is_available()
-            else nullcontext()
-        ):  # Actually better for Inductor to codegen attention here
+        with sdpa_kernel(SDPBackend.MATH):
             next_token = decode_one_token(
                 model=model,
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
                 semantic_ids=semantic_ids,
                 **sampling_kwargs,
-            )
+            ).clone()

         input_pos += 1
         cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
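This swaps the deprecated torch.backends.cuda.sdp_kernel context manager for the torch.nn.attention API imported at the top of the file, still forcing the math backend so Inductor can codegen the attention. A minimal standalone usage sketch (tiny illustrative tensors):

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 2, 4, 8)
    with sdpa_kernel(SDPBackend.MATH):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 2, 4, 8])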
@@ -223,15 +233,21 @@ def generate(
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
-    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    Generate tokens from text prompt using the transformer model.

     Args:
         model: The transformer model for generation
         prompt: Input token tensor of shape (num_codebooks+1, seq_len)
         max_new_tokens: Maximum number of new tokens to generate
         decode_one_token: Function to decode one token at a time
         **sampling_kwargs: Additional sampling parameters (temperature, top_p, repetition_penalty)

     Returns:
         Generated sequence tensor of shape (num_codebooks+1, total_seq_len)
         where total_seq_len = original_seq_len + generated_tokens_len
     """

     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
     # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
     semantic_ids = [
         model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
     ]

     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
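The buffer strategy the docstring describes, in miniature: allocate once at max_seq_len, copy in the prompt, write generated columns, then trim to what was actually produced (sizes here are illustrative, not the model's):

    import torch

    codebook_dim, max_seq_len, T, n_gen = 9, 64, 10, 20
    prompt = torch.randint(0, 32, (codebook_dim, T))

    seq = torch.empty((codebook_dim, max_seq_len), dtype=prompt.dtype)
    seq[:, :T] = prompt                                  # conditioning tokens
    seq[:, T : T + n_gen] = torch.randint(0, 32, (codebook_dim, n_gen))
    seq = seq[:, : T + n_gen]                            # trim unused tail
    assert seq.shape == (9, 30)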
@@ -246,7 +262,6 @@ def generate(
     device, dtype = prompt.device, prompt.dtype

     codebook_dim = 1 + model.config.num_codebooks
-    # create an empty tensor of the expected final shape and fill in the current tokens
     empty = torch.empty(
         (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
     )
@@ -257,33 +272,30 @@ def generate(
     # Use non-accelerated version for now, to avoid compilation overhead
     prefill_decode = decode_one_token_ar

-    next_token = prefill_decode(
+    first_token = prefill_decode(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
         semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
-    seq[:, T : T + 1] = next_token
+    seq[:, T : T + 1] = first_token

     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     x = decode_n_tokens(
         model,
-        next_token.view(1, codebook_dim, -1),
+        first_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
         decode_one_token=decode_one_token,
         semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
     seq = seq[:, : T + 1 + x.size(1)]
     seq[:, T + 1 :] = x

     return seq


-def load_model(checkpoint_path, device, precision, compile=False):
+def init_model(checkpoint_path, device, precision, compile=False):
     model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)

     model = model.to(device=device, dtype=precision)
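The first_token rename makes the two-phase flow explicit: one uncompiled prefill pass over the whole prompt, then per-token decoding on the compiled fast path. A stand-in flow (toy callables, not the repo's functions):

    import torch

    def run(prefill, decode_one, prompt, max_new_tokens):
        first = prefill(prompt)          # runs once over the whole prompt
        tokens = [first]
        for _ in range(max_new_tokens - 1):
            tokens.append(decode_one(tokens[-1]))  # per-token fast path
        return torch.cat(tokens, dim=-1)

    prompt = torch.zeros((9, 8), dtype=torch.int)
    out = run(lambda p: torch.ones((9, 1), dtype=torch.int),
              lambda t: t + 1, prompt, 4)
    assert out.shape == (9, 4)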
@@ -405,26 +417,6 @@ def generate_long(
     seg = encoded[seg_idx]
     global_encoded.append(seg)

-    # Do not use previous segments to generate current segment for now
-    # lengths = reversed([seg.size(1) for seg in global_encoded])

-    # # Pick last 2000 tokens
-    # count = 0
-    # for i, length in enumerate(lengths):
-    #     count += length
-    #     if count + length > max_length - 2048 - encoded_prompts.size(1):
-    #         break

-    # if i != 0 and i % 2 == 0:
-    #     i -= 1

-    # # Rotate the list, always make sure first segment is included to avoid drift
-    # if i < len(global_encoded) - 2:
-    #     partial_encoded = global_encoded[:2] + global_encoded[-i:]
-    # else:
-    #     partial_encoded = global_encoded

-    # cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
     if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
         cat_encoded = torch.cat(
             [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
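With the dead windowing code gone, the kept anti-drift logic is simple: always prepend the first two generated segments to the current one so the voice stays anchored to the start of the utterance. In isolation (illustrative tensors):

    import torch

    encoded_prompts = torch.ones((9, 4), dtype=torch.int)
    global_encoded = [torch.full((9, 3), i, dtype=torch.int) for i in range(5)]
    seg = global_encoded[-1]
    if len(global_encoded) >= 2:
        cat_encoded = torch.cat(
            [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
        )
    assert cat_encoded.shape == (9, 13)  # 4 + 3 + 3 + 3 columns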
@@ -507,7 +499,7 @@ def launch_thread_safe_queue(
     init_event = threading.Event()

     def worker():
-        model, decode_one_token = load_model(
+        model, decode_one_token = init_model(
             checkpoint_path, device, precision, compile=compile
         )
         with torch.device(device):
@@ -542,60 +534,6 @@ def launch_thread_safe_queue(
     return input_queue


-def launch_thread_safe_queue_agent(
-    checkpoint_path,
-    device,
-    precision,
-    compile: bool = False,
-):
-    input_queue = queue.Queue()
-    init_event = threading.Event()
-
-    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
-    config = BaseModelArgs.from_pretrained(checkpoint_path)
-
-    def worker():
-        model, decode_one_token = load_model(
-            checkpoint_path, device, precision, compile=compile, is_agent=True
-        )
-
-        with torch.device(device):
-            model.setup_caches(
-                max_batch_size=1,
-                max_seq_len=model.config.max_seq_len,
-                dtype=next(model.parameters()).dtype,
-            )
-        init_event.set()
-
-        while True:
-            item: GenerateRequest | None = input_queue.get()
-            if item is None:
-                break
-
-            kwargs = item.request
-            response_queue = item.response_queue
-
-            try:
-                for token in generate_agent(
-                    model=model,
-                    decode_one_token=decode_one_token,
-                    **kwargs,
-                ):
-                    response_queue.put(token)
-
-                response_queue.put("stop")
-            except Exception as e:
-                import traceback
-
-                logger.exception(f"Error in worker: {traceback.format_exc()}")
-                response_queue.put("error")
-
-    threading.Thread(target=worker, daemon=True).start()
-    init_event.wait()
-
-    return input_queue, tokenizer, config


 @click.command()
 @click.option(
     "--text",
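The agent variant is deleted, but the surviving launch_thread_safe_queue uses the same worker-thread pattern: a daemon thread owns the model, an Event signals readiness, and requests arrive on a queue with a None sentinel for shutdown. Reduced to a self-contained toy (no model; request/response objects simplified):

    import queue
    import threading

    def launch_toy_queue():
        input_queue: queue.Queue = queue.Queue()
        ready = threading.Event()

        def worker():
            ready.set()                   # signal that initialization finished
            while True:
                item = input_queue.get()
                if item is None:          # sentinel shuts the worker down
                    break
                request, response_queue = item
                response_queue.put(f"processed: {request}")
                response_queue.put("stop")

        threading.Thread(target=worker, daemon=True).start()
        ready.wait()
        return input_queue

    inq = launch_toy_queue()
    outq: queue.Queue = queue.Queue()
    inq.put(("hello", outq))
    print(outq.get())  # processed: hello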
@@ -654,7 +592,7 @@ def main(

     logger.info("Loading model ...")
     t0 = time.time()
-    model, decode_one_token = load_model(
+    model, decode_one_token = init_model(
         checkpoint_path, device, precision, compile=compile
     )
     with torch.device(device):