Usable version v1.0

2025-12-18 01:52:40 +08:00
commit 45dbba7d6a
7 changed files with 1510 additions and 0 deletions

.cursor Submodule

Submodule .cursor added at 67480b7ec2

.gitattributes vendored Normal file

@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto

.gitignore vendored Normal file

@@ -0,0 +1,190 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data;
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
# IDE
.vscode/
# Runtime PID file
liubai_web.pid
# Asset directories
Assets/
# StreamingAssets
StreamingAssets/

.gitmodules vendored Normal file

@@ -0,0 +1,6 @@
[submodule ".cursor"]
path = .cursor
url = http://gitea.liubai.site/ninemine/.cursor.git
[submodule "Convention"]
path = Convention
url = http://gitea.liubai.site/ninemine/Convention-Python.git

Convention Submodule

Submodule Convention added at ad17b905c4

cli.py Normal file

@@ -0,0 +1,355 @@
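# cli.py — terminal chat client: streams replies from an Ollama model via
# llama-index, splits them into sentence-sized fragments, synthesizes speech
# through an external TTS server, and plays the audio locally with PyAudio.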
from Convention.Runtime.GlobalConfig import *
import asyncio
from datetime import datetime
from llama_index.llms.ollama import Ollama
from llama_index.core.chat_engine import SimpleChatEngine
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core import Settings
from Convention.Runtime.File import ToolFile
import requests
import subprocess
import wave
import io
import numpy as np
import pyaudio
chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")
config = ProjectConfig()
OLLAMA_URL = config.FindItem("ollama_url", "http://localhost:11434")
OLLAMA_MODEL = config.FindItem("ollama_model", "gemma3:4b")
RESPONSE_TIMEOUT = config.FindItem("response_timeout", 60)
TEMPERATURE = config.FindItem("temperature", 1.3)
MAX_CONTENT_LENGTH = config.FindItem("max_content_length", None)
SYSTEM_PROMPT_PATH = config.FindItem("system_prompt_path", None)
AUTO_SPEAK_WAIT_SECOND = config.FindItem("auto_speak_wait_second", 15.0)
TTS_SERVER_URL = config.FindItem("tts_server_url", "http://localhost:43400")
TTS_PROMPT_TEXT = config.FindItem("tts_prompt_text", None)
TTS_PROMPT_WAV_PATH = config.FindItem("tts_prompt_wav_path", None)
TTS_SPEAKER_ID = config.FindItem("tts_speaker_id", "tts_speaker")
SEED = config.FindItem("seed", 0)
STREAM_ENABLE = config.FindItem("stream_enable", False)
VERBOSE = config.FindItem("verbose", False)
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_URL: {OLLAMA_URL}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_MODEL: {OLLAMA_MODEL}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"TEMPERATURE: {TEMPERATURE}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"SYSTEM_PROMPT_PATH: {SYSTEM_PROMPT_PATH}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}")
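# ToolFile appears to overload `|` as a path join, with the trailing None
# marking the result as a directory (an assumption based on how temp_dir is used).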
temp_dir = config.GetFile("temp")|chat_start_id|None
ollama_llm_config = {
"model": OLLAMA_MODEL,
"base_url": OLLAMA_URL,
"request_timeout": RESPONSE_TIMEOUT,
"temperature": TEMPERATURE,
}
chat_engine_config = {}
if MAX_CONTENT_LENGTH is not None:
ollama_llm_config["max_content_length"] = MAX_CONTENT_LENGTH
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
if SYSTEM_PROMPT_PATH is not None:
system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
chat_engine_config["system_prompt"] = system_prompt
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"system_prompt: {system_prompt}")
config.SaveProperties()
def save_vocal_data(data:bytes) -> ToolFile:
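"""Save one chunk of synthesized audio under the per-session temp directory"""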
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{timestamp}.wav"
file = temp_dir|filename
file.MustExistsPath()
file.SaveAsBinary(data)
return file
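# Playback queue of complete WAV payloads, plus lazily (re)created PyAudio state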
audio_play_queue: asyncio.Queue[bytes] = asyncio.Queue()
_pyaudio_instance: pyaudio.PyAudio | None = None
_pyaudio_stream: pyaudio.Stream | None = None
_current_sample_rate: int | None = None
# TOCHECK
def parse_wav_chunk(wav_bytes: bytes) -> tuple[np.ndarray, int]:
"""
解析WAV数据返回音频数组和采样率
"""
wav_file = wave.open(io.BytesIO(wav_bytes))
sample_rate = wav_file.getframerate()
n_channels = wav_file.getnchannels()
sample_width = wav_file.getsampwidth()
n_frames = wav_file.getnframes()
audio_bytes = wav_file.readframes(n_frames)
wav_file.close()
if sample_width == 2:
audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
elif sample_width == 4:
audio_data_32 = np.frombuffer(audio_bytes, dtype=np.int32)
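# Peak-normalize 32-bit samples into the int16 range (not a plain bit-depth shift)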
max_val = np.abs(audio_data_32).max()
if max_val > 0:
audio_data = (audio_data_32 / max_val * 32767).astype(np.int16)
else:
audio_data = np.zeros(len(audio_data_32), dtype=np.int16)
else:
raise ValueError(f"Unsupported sample width: {sample_width}")
if n_channels == 2:
audio_data = audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)
audio_data = np.clip(audio_data, -32768, 32767).astype(np.int16)
return audio_data, sample_rate
# TOCHECK
def play_audio_chunk(audio_data: np.ndarray, sample_rate: int) -> None:
"""
使用PyAudio播放音频数组
"""
global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
if _pyaudio_instance is None or _current_sample_rate != sample_rate:
if _pyaudio_stream is not None:
try:
_pyaudio_stream.stop_stream()
_pyaudio_stream.close()
except Exception:
pass
if _pyaudio_instance is not None:
try:
_pyaudio_instance.terminate()
except Exception:
pass
frames_per_buffer = max(int(sample_rate * 0.02), 256)
_pyaudio_instance = pyaudio.PyAudio()
_pyaudio_stream = _pyaudio_instance.open(
format=pyaudio.paInt16,
channels=1,
rate=sample_rate,
output=True,
frames_per_buffer=frames_per_buffer
)
_current_sample_rate = sample_rate
if audio_data.dtype != np.int16:
max_val = np.abs(audio_data).max()
if max_val > 0:
audio_data = (audio_data / max_val * 32767).astype(np.int16)
else:
audio_data = np.zeros_like(audio_data, dtype=np.int16)
chunk_size = 4096
audio_bytes = audio_data.tobytes()
for i in range(0, len(audio_bytes), chunk_size):
chunk = audio_bytes[i:i + chunk_size]
if _pyaudio_stream is not None:
_pyaudio_stream.write(chunk)
# TOCHECK
def cleanup_audio() -> None:
"""
释放PyAudio资源
"""
global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
if _pyaudio_stream is not None:
try:
_pyaudio_stream.stop_stream()
_pyaudio_stream.close()
except Exception:
pass
_pyaudio_stream = None
if _pyaudio_instance is not None:
try:
_pyaudio_instance.terminate()
except Exception:
pass
_pyaudio_instance = None
_current_sample_rate = None
# TOCHECK
def play_audio_sync(audio_data: bytes) -> None:
if not audio_data:
return
audio_array, sample_rate = parse_wav_chunk(audio_data)
play_audio_chunk(audio_array, sample_rate)
# TOCHECK
async def audio_player_worker():
"""
音频播放后台任务,确保音频按顺序播放
"""
while True:
audio_data = await audio_play_queue.get()
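# A None item is the shutdown sentinel pushed by event_loop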
if audio_data is None:
audio_play_queue.task_done()
break
try:
await asyncio.to_thread(play_audio_sync, audio_data)
await asyncio.to_thread(save_vocal_data, audio_data)
except Exception as exc:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Audio playback failed: {exc}")
finally:
audio_play_queue.task_done()
# CHANGE TOCHECK
async def play_vocal(text:str) -> None:
if not text:
return
tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft"
# Prepare the request payload
header = {
"accept": "application/json"
}
data = {
'text': text,
'speaker_id': TTS_SPEAKER_ID,
'stream': STREAM_ENABLE,
}
# Send the POST request
if STREAM_ENABLE:
response = requests.post(tts_server_url, data=data, stream=True, timeout=600, headers=header)
if response.status_code != 200:
raise Exception(f"语音合成失败: {response.status_code} - {response.text}")
wav_buffer = bytearray()
for chunk in response.iter_content(chunk_size=1024 * 256):
if not chunk:
continue
wav_buffer.extend(chunk)
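# The stream may carry several concatenated WAV files: locate each RIFF
# header and emit complete files of (declared chunk size + 8) bytes.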
while len(wav_buffer) > 12:
if wav_buffer[:4] != b'RIFF':
riff_pos = wav_buffer.find(b'RIFF', 1)
if riff_pos == -1:
wav_buffer.clear()
break
wav_buffer = wav_buffer[riff_pos:]
if len(wav_buffer) < 8:
break
file_size = int.from_bytes(wav_buffer[4:8], byteorder='little')
expected_size = file_size + 8
if len(wav_buffer) < expected_size:
break
complete_wav = bytes(wav_buffer[:expected_size])
del wav_buffer[:expected_size]
await audio_play_queue.put(complete_wav)
if wav_buffer:
leftover = bytes(wav_buffer)
try:
parse_wav_chunk(leftover)
await audio_play_queue.put(leftover)
except Exception:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, "Failed to parse leftover audio data; discarded")
else:
response = requests.post(tts_server_url, data=data, timeout=600, headers=header)
if response.status_code == 200:
await audio_play_queue.put(response.content)
else:
raise Exception(f"语音合成失败: {response.status_code} - {response.text}")
async def ainput(wait_seconds:float) -> str:
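"""Wait up to wait_seconds for console input; return an empty string on timeout"""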
loop = asyncio.get_event_loop()
def get_input():
try:
return input("\nYou: ")
except EOFError:
return ""
input_task = loop.run_in_executor(None, get_input)
while wait_seconds > 0:
if input_task.done():
return input_task.result()
await asyncio.sleep(0.5)
wait_seconds -= 0.5
return ""
async def achat(engine:SimpleChatEngine,message:str) -> None:
user_message = message if message not in [None, ""] else "(No one is speaking; please continue the conversation or find a new topic)"
streaming_response: StreamingAgentChatResponse = await engine.astream_chat(user_message)
buffer_response = ""
end_symbol = ['。', '！', '？']  # sentence-ending punctuation (assumed; the original characters are unreadable)
# Print the streamed text in real time
async for chunk in streaming_response.async_response_gen():
await asyncio.sleep(0.01)
print(chunk, end='', flush=True)
for ch in chunk:
buffer_response += ch
if len(buffer_response) > 20:
if ch in end_symbol:
await play_vocal(buffer_response.strip())
buffer_response = ""
buffer_response = buffer_response.strip()
if len(buffer_response) > 0:
await play_vocal(buffer_response)
def add_speaker() -> None:
url = f"{TTS_SERVER_URL}/api/speakers/add"
headers = {
"accept": "application/json"
}
data = {
"speaker_id": TTS_SPEAKER_ID,
"prompt_text": TTS_PROMPT_TEXT,
"force_regenerate": True
}
with open(TTS_PROMPT_WAV_PATH, 'rb') as f:
extension = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower()
files = {
'prompt_wav': (f'prompt.{extension}', f, f'audio/{extension}')
}
response = requests.post(url, data=data, files=files, headers=headers, timeout=600)
if response.status_code == 200:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"Speaker ready: {response.text}")
else:
raise Exception(f"添加音色失败: {response.status_code} - {response.text}")
async def event_loop(engine:SimpleChatEngine) -> None:
add_speaker()
audio_player_task = asyncio.create_task(audio_player_worker())
message = input("Start the conversation: ")
wait_second = AUTO_SPEAK_WAIT_SECOND
try:
while message != "quit" and message != "exit":
PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='')
await achat(engine, message)
PrintColorful(ConsoleFrontColor.RESET,"")
message = await ainput(wait_second)
if not message:
wait_second = min(wait_second * 1.5, 3600)  # back off while idle, capped at one hour
else:
wait_second = AUTO_SPEAK_WAIT_SECOND
finally:
await audio_play_queue.join()
await audio_play_queue.put(None)
await audio_player_task
cleanup_audio()
async def main():
# Initialize
try:
ollama_llm = Ollama(**ollama_llm_config)
Settings.llm = ollama_llm
chat_engine = SimpleChatEngine.from_defaults(**chat_engine_config)
await event_loop(chat_engine)
except Exception as e:
config.Log("Error", f"Error: {e}")
return
finally:
cleanup_audio()
if __name__ == "__main__":
asyncio.run(main())

web_server.py Normal file

@@ -0,0 +1,955 @@
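# web_server.py — FastAPI port of the CLI client: serves a single-page chat UI,
# streams model output over a WebSocket, and pushes the TTS audio to the
# browser as base64-encoded WAV messages.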
from Convention.Runtime.GlobalConfig import *
import asyncio
from datetime import datetime
from llama_index.llms.ollama import Ollama
from llama_index.core.chat_engine import SimpleChatEngine
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core import Settings
from Convention.Runtime.File import ToolFile
import requests
import wave
import io
import base64
import json
import time
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
from typing import Optional, Set
from pydantic import BaseModel
chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")
# Global state
config: Optional[ProjectConfig] = None
chat_engine: Optional[SimpleChatEngine] = None
connected_clients: Set[WebSocket] = set()
temp_dir: Optional[ToolFile] = None
last_message_time: float = 0.0
auto_speak_task: Optional[asyncio.Task] = None
is_processing: bool = False  # whether a message is currently being processed
current_wait_second: float = 15.0  # current auto-speak wait interval (adjusted dynamically)
last_user_message_time: float = 0.0  # time of the last user message (used to detect replies)
# Configuration values
OLLAMA_URL: str = "http://localhost:11434"
OLLAMA_MODEL: str = "gemma3:4b"
RESPONSE_TIMEOUT: int = 60
TEMPERATURE: float = 1.3
MAX_CONTENT_LENGTH: Optional[int] = None
SYSTEM_PROMPT_PATH: Optional[str] = None
TTS_SERVER_URL: str = "http://localhost:43400"
TTS_PROMPT_TEXT: Optional[str] = None
TTS_PROMPT_WAV_PATH: Optional[str] = None
TTS_SPEAKER_ID: str = "tts_speaker"
STREAM_ENABLE: bool = False
VERBOSE: bool = False
AUTO_SPEAK_WAIT_SECOND: float = 15.0
def initialize_config():
"""初始化配置"""
global config, temp_dir, OLLAMA_URL, OLLAMA_MODEL, RESPONSE_TIMEOUT, TEMPERATURE
global MAX_CONTENT_LENGTH, SYSTEM_PROMPT_PATH, TTS_SERVER_URL, TTS_PROMPT_TEXT
global TTS_PROMPT_WAV_PATH, TTS_SPEAKER_ID, STREAM_ENABLE, VERBOSE, AUTO_SPEAK_WAIT_SECOND
config = ProjectConfig()
OLLAMA_URL = config.FindItem("ollama_url", "http://localhost:11434")
OLLAMA_MODEL = config.FindItem("ollama_model", "gemma3:4b")
RESPONSE_TIMEOUT = config.FindItem("response_timeout", 60)
TEMPERATURE = config.FindItem("temperature", 1.3)
MAX_CONTENT_LENGTH = config.FindItem("max_content_length", None)
SYSTEM_PROMPT_PATH = config.FindItem("system_prompt_path", None)
TTS_SERVER_URL = config.FindItem("tts_server_url", "http://localhost:43400")
TTS_PROMPT_TEXT = config.FindItem("tts_prompt_text", None)
TTS_PROMPT_WAV_PATH = config.FindItem("tts_prompt_wav_path", None)
TTS_SPEAKER_ID = config.FindItem("tts_speaker_id", "tts_speaker")
STREAM_ENABLE = config.FindItem("stream_enable", False)
VERBOSE = config.FindItem("verbose", False)
AUTO_SPEAK_WAIT_SECOND = config.FindItem("auto_speak_wait_second", 15.0)
temp_dir = config.GetFile("temp") | chat_start_id | None
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"OLLAMA_URL: {OLLAMA_URL}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"OLLAMA_MODEL: {OLLAMA_MODEL}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"TEMPERATURE: {TEMPERATURE}")
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}")
config.SaveProperties()
def initialize_chat_engine():
"""初始化聊天引擎"""
global chat_engine
ollama_llm_config = {
"model": OLLAMA_MODEL,
"base_url": OLLAMA_URL,
"request_timeout": RESPONSE_TIMEOUT,
"temperature": TEMPERATURE,
}
chat_engine_config = {}
if MAX_CONTENT_LENGTH is not None:
ollama_llm_config["max_content_length"] = MAX_CONTENT_LENGTH
if SYSTEM_PROMPT_PATH is not None:
system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
chat_engine_config["system_prompt"] = system_prompt
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "system_prompt loaded")
ollama_llm = Ollama(**ollama_llm_config)
Settings.llm = ollama_llm
chat_engine = SimpleChatEngine.from_defaults(**chat_engine_config)
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "Chat engine initialized")
def save_vocal_data(data: bytes) -> ToolFile:
"""保存音频数据"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{timestamp}.wav"
file = temp_dir | filename
file.MustExistsPath()
file.SaveAsBinary(data)
return file
async def generate_tts_audio(text: str) -> Optional[bytes]:
"""生成TTS音频返回音频字节数据"""
if not text:
return None
if TTS_PROMPT_WAV_PATH is None:
return None
tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft"
header = {
"accept": "application/json"
}
data = {
'text': text,
'speaker_id': TTS_SPEAKER_ID,
'stream': STREAM_ENABLE,
}
def _generate_sync():
"""同步生成TTS音频"""
try:
if STREAM_ENABLE:
response = requests.post(tts_server_url, data=data, stream=True, timeout=600, headers=header)
if response.status_code != 200:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"TTS failed: {response.status_code} - {response.text}")
return None
wav_buffer = bytearray()
for chunk in response.iter_content(chunk_size=1024 * 256):
if not chunk:
continue
wav_buffer.extend(chunk)
if wav_buffer:
complete_wav = bytes(wav_buffer)
save_vocal_data(complete_wav)
return complete_wav
return None
else:
response = requests.post(tts_server_url, data=data, timeout=600, headers=header)
if response.status_code == 200:
audio_data = response.content
save_vocal_data(audio_data)
return audio_data
else:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"TTS failed: {response.status_code} - {response.text}")
return None
except Exception as e:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"TTS exception: {e}")
return None
# Run the blocking request in a worker thread
return await asyncio.to_thread(_generate_sync)
def add_speaker() -> None:
"""添加TTS音色"""
if TTS_PROMPT_WAV_PATH is None or TTS_PROMPT_TEXT is None:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "TTS speaker configuration incomplete; skipping initialization")
return
url = f"{TTS_SERVER_URL}/api/speakers/add"
headers = {
"accept": "application/json"
}
data = {
"speaker_id": TTS_SPEAKER_ID,
"prompt_text": TTS_PROMPT_TEXT,
"force_regenerate": True
}
try:
with open(TTS_PROMPT_WAV_PATH, 'rb') as f:
extension = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower()
files = {
'prompt_wav': (f'prompt.{extension}', f, f'audio/{extension}')
}
response = requests.post(url, data=data, files=files, headers=headers, timeout=600)
if response.status_code == 200:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"Speaker ready: {response.text}")
else:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Failed to add speaker: {response.status_code} - {response.text}")
except Exception as e:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Exception while adding speaker: {e}")
async def auto_speak_task_func():
"""自动发言后台任务"""
global chat_engine, last_message_time, config, connected_clients, is_processing, current_wait_second, last_user_message_time
# Initialize timestamps and the wait interval
last_message_time = time.time()
last_user_message_time = time.time()
current_wait_second = AUTO_SPEAK_WAIT_SECOND
while True:
# Sleep for the current (dynamically adjusted) interval
await asyncio.sleep(current_wait_second)
if chat_engine is None or config is None:
continue
current_time = time.time()
# Fire only when clients are connected, nothing is currently being processed,
# and the user has been silent longer than the current interval
# (last_user_message_time counts only user messages, not auto-speak turns)
if (connected_clients and
not is_processing and
(current_time - last_user_message_time) >= current_wait_second):
try:
# Trigger an auto-speak turn
auto_message = "(No one is speaking; please continue the conversation or find a new topic)"
# Send the auto-speak turn to every connected client
for websocket in list(connected_clients):
try:
await handle_chat_stream(websocket, auto_message)
except Exception as e:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Auto-speak send error: {e}")
# Give the user a chance to respond: wait 10% of the current interval, but at least 1 second
check_wait_time = max(current_wait_second * 0.1, 1.0)
await asyncio.sleep(check_wait_time)
# Check for a reply: was last_user_message_time updated after the auto-speak turn?
time_after_check = time.time()
if (time_after_check - last_user_message_time) > check_wait_time + 2.0:
# No reply; gradually increase the wait interval
current_wait_second = min(current_wait_second * 1.5, 3600.0)
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"No user reply; wait interval adjusted to: {current_wait_second}")
# If the user does reply, handle_chat_stream resets the interval
except Exception as e:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Auto-speak task error: {e}")
# FastAPI application
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
global auto_speak_task, last_message_time
# 启动时初始化
initialize_config()
initialize_chat_engine()
add_speaker()
last_message_time = time.time()
last_user_message_time = time.time()
# Start the auto-speak background task
auto_speak_task = asyncio.create_task(auto_speak_task_func())
yield
# Shutdown cleanup
if auto_speak_task and not auto_speak_task.done():
auto_speak_task.cancel()
try:
await auto_speak_task
except asyncio.CancelledError:
pass
if config:
config.SaveProperties()
app = FastAPI(lifespan=lifespan)
# WebSocket connection management
async def connect_client(websocket: WebSocket):
"""添加客户端连接"""
await websocket.accept()
connected_clients.add(websocket)
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, f"Client connected; active connections: {len(connected_clients)}")
async def disconnect_client(websocket: WebSocket):
"""移除客户端连接"""
connected_clients.discard(websocket)
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"Client disconnected; active connections: {len(connected_clients)}")
# API models
class ChatRequest(BaseModel):
message: str
async def safe_send_json(websocket: WebSocket, data: dict) -> bool:
"""安全地发送 JSON 消息,检查连接状态"""
if websocket not in connected_clients:
return False
try:
await websocket.send_json(data)
return True
except (WebSocketDisconnect, RuntimeError, ConnectionError):
if websocket in connected_clients:
await disconnect_client(websocket)
return False
except Exception as e:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Send error: {e}")
return False
async def generate_and_send_audio(websocket: WebSocket, text: str):
"""生成音频并发送到客户端"""
try:
# 检查 WebSocket 是否仍然连接
if websocket not in connected_clients:
return
audio_data = await generate_tts_audio(text)
# Re-check: the client may have disconnected while audio was being generated
if websocket not in connected_clients:
return
if audio_data:
audio_base64 = base64.b64encode(audio_data).decode('utf-8')
await safe_send_json(websocket, {
"type": "audio",
"audio": audio_base64
})
except (WebSocketDisconnect, RuntimeError, ConnectionError) as e:
# Connection already closed; ignore
if websocket in connected_clients:
await disconnect_client(websocket)
except Exception as e:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Audio generation error: {e}")
# WebSocket endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket端点处理流式聊天"""
await connect_client(websocket)
try:
while True:
# Receive a client message
data = await websocket.receive_text()
message_data = json.loads(data)
if message_data.get("type") == "chat":
message = message_data.get("message", "")
if not message:
await safe_send_json(websocket, {
"type": "error",
"message": "消息不能为空"
})
continue
# Handle the chat request
await handle_chat_stream(websocket, message)
elif message_data.get("type") == "ping":
await safe_send_json(websocket, {"type": "pong"})
except WebSocketDisconnect:
await disconnect_client(websocket)
except Exception as e:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"WebSocket error: {e}")
await disconnect_client(websocket)
async def handle_chat_stream(websocket: WebSocket, message: str):
"""处理流式聊天"""
global chat_engine, last_message_time, is_processing, current_wait_second, last_user_message_time
if chat_engine is None:
await safe_send_json(websocket, {
"type": "error",
"message": "聊天引擎未初始化"
})
return
# Update last-message time and the processing flag
last_message_time = time.time()
is_processing = True
# For a real user message (not auto-speak), update the user timestamp and reset the wait interval
if message != "(No one is speaking; please continue the conversation or find a new topic)":
last_user_message_time = time.time()
current_wait_second = AUTO_SPEAK_WAIT_SECOND
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, f"User message received; wait interval reset to: {current_wait_second}")
user_message = message if message not in [None, ""] else "(No one is speaking; please continue the conversation or find a new topic)"
try:
streaming_response: StreamingAgentChatResponse = await chat_engine.astream_chat(user_message)
buffer_response = ""
end_symbol = ['。', '！', '？']  # sentence-ending punctuation (assumed; the original characters are unreadable)
audio_text_buffer = ""
# Send the start marker
if not await safe_send_json(websocket, {
"type": "start",
"message": ""
}):
return
# Stream the response
async for chunk in streaming_response.async_response_gen():
# Check connection state
if websocket not in connected_clients:
break
await asyncio.sleep(0.01)
buffer_response += chunk
# Send the text chunk
if not await safe_send_json(websocket, {
"type": "chunk",
"message": chunk
}):
break
# Decide whether to synthesize audio for the buffered text
for ch in chunk:
audio_text_buffer += ch
if len(audio_text_buffer) > 20:
if ch in end_symbol:
text_to_speak = audio_text_buffer.strip()
if text_to_speak:
# Generate audio asynchronously (does not block the text stream)
asyncio.create_task(generate_and_send_audio(websocket, text_to_speak))
audio_text_buffer = ""
# Check connection state
if websocket not in connected_clients:
return
# Handle any remaining text
if buffer_response.strip():
# Send the completion message
await safe_send_json(websocket, {
"type": "complete",
"message": buffer_response.strip()
})
# Synthesize audio for leftover text
if audio_text_buffer.strip():
asyncio.create_task(generate_and_send_audio(websocket, audio_text_buffer.strip()))
else:
await safe_send_json(websocket, {
"type": "complete",
"message": ""
})
except (WebSocketDisconnect, RuntimeError, ConnectionError):
if websocket in connected_clients:
await disconnect_client(websocket)
except Exception as e:
if VERBOSE:
PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Chat handling error: {e}")
await safe_send_json(websocket, {
"type": "error",
"message": f"处理请求时发生错误: {str(e)}"
})
finally:
# Reset the processing flag
is_processing = False
# Static file serving
@app.get("/", response_class=HTMLResponse)
async def read_root():
"""返回前端页面"""
return HTMLResponse(content=get_html_content())
def get_html_content() -> str:
"""生成HTML内容"""
return """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>VirtualChat - AI Chat</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
padding: 20px;
}
.container {
width: 100%;
max-width: 800px;
height: 90vh;
background: rgba(255, 255, 255, 0.95);
border-radius: 24px;
box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
display: flex;
flex-direction: column;
overflow: hidden;
backdrop-filter: blur(10px);
}
.header {
padding: 24px;
border-bottom: 1px solid rgba(0, 0, 0, 0.1);
background: rgba(255, 255, 255, 0.8);
backdrop-filter: blur(10px);
}
.header h1 {
font-size: 24px;
font-weight: 600;
color: #1d1d1f;
text-align: center;
}
.messages {
flex: 1;
overflow-y: auto;
padding: 24px;
display: flex;
flex-direction: column;
gap: 16px;
}
.messages::-webkit-scrollbar {
width: 6px;
}
.messages::-webkit-scrollbar-track {
background: transparent;
}
.messages::-webkit-scrollbar-thumb {
background: rgba(0, 0, 0, 0.2);
border-radius: 3px;
}
.message {
display: flex;
gap: 12px;
animation: fadeIn 0.3s ease-in;
}
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.message.user {
justify-content: flex-end;
}
.message-content {
max-width: 70%;
padding: 14px 18px;
border-radius: 20px;
word-wrap: break-word;
line-height: 1.5;
font-size: 15px;
}
.message.user .message-content {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-bottom-right-radius: 4px;
}
.message.ai .message-content {
background: #f5f5f7;
color: #1d1d1f;
border-bottom-left-radius: 4px;
}
.input-area {
padding: 24px;
border-top: 1px solid rgba(0, 0, 0, 0.1);
background: rgba(255, 255, 255, 0.8);
backdrop-filter: blur(10px);
}
.input-container {
display: flex;
gap: 12px;
align-items: flex-end;
}
.input-wrapper {
flex: 1;
position: relative;
}
#messageInput {
width: 100%;
padding: 14px 18px;
border: 2px solid rgba(0, 0, 0, 0.1);
border-radius: 24px;
font-size: 15px;
font-family: inherit;
outline: none;
transition: all 0.3s ease;
background: white;
resize: none;
min-height: 50px;
max-height: 120px;
}
#messageInput:focus {
border-color: #667eea;
box-shadow: 0 0 0 4px rgba(102, 126, 234, 0.1);
}
.send-button {
padding: 14px 28px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 24px;
font-size: 15px;
font-weight: 600;
cursor: pointer;
transition: all 0.3s ease;
white-space: nowrap;
}
.send-button:hover:not(:disabled) {
transform: translateY(-2px);
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4);
}
.send-button:active:not(:disabled) {
transform: translateY(0);
}
.send-button:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.status {
text-align: center;
padding: 12px;
color: #86868b;
font-size: 13px;
}
.status.connected {
color: #30d158;
}
.status.disconnected {
color: #ff453a;
}
.typing-indicator {
display: none;
padding: 14px 18px;
background: #f5f5f7;
border-radius: 20px;
border-bottom-left-radius: 4px;
max-width: 70px;
}
.typing-indicator.active {
display: block;
}
.typing-dots {
display: flex;
gap: 4px;
}
.typing-dot {
width: 8px;
height: 8px;
background: #86868b;
border-radius: 50%;
animation: typing 1.4s infinite;
}
.typing-dot:nth-child(2) {
animation-delay: 0.2s;
}
.typing-dot:nth-child(3) {
animation-delay: 0.4s;
}
@keyframes typing {
0%, 60%, 100% {
transform: translateY(0);
}
30% {
transform: translateY(-10px);
}
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>VirtualChat</h1>
</div>
<div class="messages" id="messages"></div>
<div class="input-area">
<div class="input-container">
<div class="input-wrapper">
<textarea id="messageInput" placeholder="输入消息..." rows="1"></textarea>
</div>
<button class="send-button" id="sendButton">发送</button>
</div>
<div class="status" id="status">连接中...</div>
</div>
</div>
<script>
let ws = null;
let currentAiMessage = null;
let audioQueue = [];
let isPlayingAudio = false;
const messagesDiv = document.getElementById('messages');
const messageInput = document.getElementById('messageInput');
const sendButton = document.getElementById('sendButton');
const statusDiv = document.getElementById('status');
function connect() {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${protocol}//${window.location.host}/ws`;
ws = new WebSocket(wsUrl);
ws.onopen = () => {
statusDiv.textContent = 'Connected';
statusDiv.className = 'status connected';
sendButton.disabled = false;
};
ws.onclose = () => {
statusDiv.textContent = 'Disconnected';
statusDiv.className = 'status disconnected';
sendButton.disabled = true;
setTimeout(connect, 3000);
};
ws.onerror = (error) => {
console.error('WebSocket error:', error);
statusDiv.textContent = 'Connection error';
statusDiv.className = 'status disconnected';
};
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
handleMessage(data);
};
}
function handleMessage(data) {
if (data.type === 'start') {
currentAiMessage = createMessage('ai', '');
showTypingIndicator();
} else if (data.type === 'chunk') {
hideTypingIndicator();
if (currentAiMessage) {
currentAiMessage.textContent += data.message;
scrollToBottom();
}
} else if (data.type === 'complete') {
hideTypingIndicator();
if (currentAiMessage && data.message) {
currentAiMessage.textContent = data.message;
}
currentAiMessage = null;
scrollToBottom();
} else if (data.type === 'audio') {
playAudio(data.audio);
} else if (data.type === 'error') {
hideTypingIndicator();
addMessage('ai', 'Error: ' + data.message);
} else if (data.type === 'auto_speak') {
// Auto-speak message (handled like a normal message)
if (data.message) {
addMessage('ai', data.message);
}
if (data.audio) {
playAudio(data.audio);
}
} else if (data.type === 'pong') {
// Heartbeat response
}
}
function createMessage(role, text) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${role}`;
const contentDiv = document.createElement('div');
contentDiv.className = 'message-content';
contentDiv.textContent = text;
messageDiv.appendChild(contentDiv);
messagesDiv.appendChild(messageDiv);
scrollToBottom();
return contentDiv;
}
function addMessage(role, text) {
createMessage(role, text);
}
function showTypingIndicator() {
let indicator = document.getElementById('typingIndicator');
if (!indicator) {
indicator = document.createElement('div');
indicator.id = 'typingIndicator';
indicator.className = 'message ai';
indicator.innerHTML = `
<div class="typing-indicator active">
<div class="typing-dots">
<div class="typing-dot"></div>
<div class="typing-dot"></div>
<div class="typing-dot"></div>
</div>
</div>
`;
messagesDiv.appendChild(indicator);
} else {
indicator.querySelector('.typing-indicator').classList.add('active');
}
scrollToBottom();
}
function hideTypingIndicator() {
const indicator = document.getElementById('typingIndicator');
if (indicator) {
indicator.querySelector('.typing-indicator').classList.remove('active');
}
}
function scrollToBottom() {
messagesDiv.scrollTop = messagesDiv.scrollHeight;
}
function playAudio(audioBase64) {
audioQueue.push(audioBase64);
processAudioQueue();
}
function processAudioQueue() {
if (isPlayingAudio || audioQueue.length === 0) {
return;
}
isPlayingAudio = true;
const audioBase64 = audioQueue.shift();
const audio = new Audio('data:audio/wav;base64,' + audioBase64);
audio.onended = () => {
isPlayingAudio = false;
processAudioQueue();
};
audio.onerror = () => {
isPlayingAudio = false;
processAudioQueue();
};
audio.play().catch(err => {
console.error('Audio play error:', err);
isPlayingAudio = false;
processAudioQueue();
});
}
function sendMessage() {
const message = messageInput.value.trim();
if (!message || !ws || ws.readyState !== WebSocket.OPEN) {
return;
}
addMessage('user', message);
messageInput.value = '';
messageInput.style.height = 'auto';
ws.send(JSON.stringify({
type: 'chat',
message: message
}));
}
sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keydown', (e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendMessage();
}
});
messageInput.addEventListener('input', () => {
messageInput.style.height = 'auto';
messageInput.style.height = messageInput.scrollHeight + 'px';
});
connect();
</script>
</body>
</html>"""
# Health check
@app.get("/health")
def health_check():
"""健康检查端点"""
return {
"status": "ok",
"chat_engine_ready": chat_engine is not None,
"connected_clients": len(connected_clients)
}
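# Example check (assuming the default port configured below):
#   curl http://localhost:11451/health
#   -> {"status": "ok", "chat_engine_ready": true, "connected_clients": 0}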
# Entry point
if __name__ == "__main__":
import uvicorn
server_host = config.FindItem("server_host", "0.0.0.0") if config else "0.0.0.0"
server_port = config.FindItem("server_port", 11451) if config else 11451
uvicorn.run(app, host=server_host, port=server_port)
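# Minimal WebSocket client sketch (assumes the third-party `websockets` package
# and the default port; illustrative only, not part of this project):
#   import asyncio, json, websockets
#   async def demo():
#       async with websockets.connect("ws://localhost:11451/ws") as ws:
#           await ws.send(json.dumps({"type": "chat", "message": "hello"}))
#           while True:
#               msg = json.loads(await ws.recv())
#               print(msg["type"])  # start, chunk, audio, complete, or error
#               if msg["type"] in ("complete", "error"):
#                   break
#   asyncio.run(demo())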