From 45dbba7d6a002e1550d352c2e077bfda3f89c9f0 Mon Sep 17 00:00:00 2001 From: ninemine <1371605831@qq.com> Date: Thu, 18 Dec 2025 01:52:40 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=AF=E7=94=A8=E7=89=88=E6=9C=ACv1.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .cursor | 1 + .gitattributes | 2 + .gitignore | 190 ++++++++++ .gitmodules | 6 + Convention | 1 + cli.py | 355 ++++++++++++++++++ web_server.py | 955 +++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 1510 insertions(+) create mode 160000 .cursor create mode 100644 .gitattributes create mode 100644 .gitignore create mode 100644 .gitmodules create mode 160000 Convention create mode 100644 cli.py create mode 100644 web_server.py diff --git a/.cursor b/.cursor new file mode 160000 index 0000000..67480b7 --- /dev/null +++ b/.cursor @@ -0,0 +1 @@ +Subproject commit 67480b7ec270ea5864d3d2a723e7d3cc94fd2c0a diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d0f2fb9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,190 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. 
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor.`.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore +# IDE +.vscode/ + +# Database +liubai_web.pid +Assets/ + +# StreamingAssets +StreamingAssets/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..7f41c75 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule ".cursor"] + path = .cursor + url = http://gitea.liubai.site/ninemine/.cursor.git +[submodule "Convention"] + path = Convention + url = http://gitea.liubai.site/ninemine/Convention-Python.git diff --git a/Convention b/Convention new file mode 160000 index 0000000..ad17b90 --- /dev/null +++ b/Convention @@ -0,0 +1 @@ +Subproject commit ad17b905c4889facecb3fd29e3a8f73d4aae2813 diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..43fae44 --- /dev/null +++ b/cli.py @@ -0,0 +1,355 @@ + +from Convention.Runtime.GlobalConfig import * +import asyncio +from datetime import datetime +from llama_index.llms.ollama import Ollama +from llama_index.core.chat_engine import SimpleChatEngine +from llama_index.core.chat_engine.types import StreamingAgentChatResponse +from llama_index.core import Settings +from Convention.Runtime.File import ToolFile +import requests +import subprocess +import wave +import io +import numpy as np +import pyaudio + +chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S") + + +config = ProjectConfig() +OLLAMA_URL = config.FindItem("ollama_url", "http://localhost:11434") +OLLAMA_MODEL = config.FindItem("ollama_model", "gemma3:4b") +RESPONSE_TIMEOUT = config.FindItem("response_timeout", 60) +TEMPERATURE = config.FindItem("temperature", 1.3) +MAX_CONTENT_LENGTH = config.FindItem("max_content_length", None) +SYSTEM_PROMPT_PATH = config.FindItem("system_prompt_path", None) +AUTO_SPEAK_WAIT_SECOND = config.FindItem("auto_speak_wait_second", 15.0) +TTS_SERVER_URL = config.FindItem("tts_server_url", "http://localhost:43400") +TTS_PROMPT_TEXT = 
config.FindItem("tts_prompt_text", None) +TTS_PROMPT_WAV_PATH = config.FindItem("tts_prompt_wav_path", None) +TTS_SPEAKER_ID = config.FindItem("tts_speaker_id", "tts_speaker") +SEED = config.FindItem("seed", 0) +STREAM_ENABLE = config.FindItem("stream_enable", False) +VERBOSE = config.FindItem("verbose", False) +if VERBOSE: + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_URL: {OLLAMA_URL}") + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_MODEL: {OLLAMA_MODEL}") + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}") + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"TEMPERATURE: {TEMPERATURE}") + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}") + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"SYSTEM_PROMPT_PATH: {SYSTEM_PROMPT_PATH}") + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}") +temp_dir = config.GetFile("temp")|chat_start_id|None + +ollama_llm_config = { + "model": OLLAMA_MODEL, + "base_url": OLLAMA_URL, + "request_timeout": RESPONSE_TIMEOUT, + "temperature": TEMPERATURE, +} +chat_engine_config = {} + +if MAX_CONTENT_LENGTH is not None: + ollama_llm_config["max_content_length"] = MAX_CONTENT_LENGTH + if VERBOSE: + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}") + +if SYSTEM_PROMPT_PATH is not None: + system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText() + chat_engine_config["system_prompt"] = system_prompt + if VERBOSE: + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"system_prompt: {system_prompt}") + +config.SaveProperties() + +def save_vocal_data(data:bytes) -> ToolFile: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + filename = f"{timestamp}.wav" + file = temp_dir|filename + file.MustExistsPath() + file.SaveAsBinary(data) + return file + +audio_play_queue: asyncio.Queue[bytes] = asyncio.Queue() +_pyaudio_instance: pyaudio.PyAudio | None = None 
+_pyaudio_stream: pyaudio.Stream | None = None +_current_sample_rate: int | None = None + + +# TOCHECK +def parse_wav_chunk(wav_bytes: bytes) -> tuple[np.ndarray, int]: + """ + 解析WAV数据,返回音频数组和采样率 + """ + wav_file = wave.open(io.BytesIO(wav_bytes)) + sample_rate = wav_file.getframerate() + n_channels = wav_file.getnchannels() + sample_width = wav_file.getsampwidth() + n_frames = wav_file.getnframes() + audio_bytes = wav_file.readframes(n_frames) + wav_file.close() + + if sample_width == 2: + audio_data = np.frombuffer(audio_bytes, dtype=np.int16) + elif sample_width == 4: + audio_data_32 = np.frombuffer(audio_bytes, dtype=np.int32) + max_val = np.abs(audio_data_32).max() + if max_val > 0: + audio_data = (audio_data_32 / max_val * 32767).astype(np.int16) + else: + audio_data = np.zeros(len(audio_data_32), dtype=np.int16) + else: + raise ValueError(f"Unsupported sample width: {sample_width}") + + if n_channels == 2: + audio_data = audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16) + + audio_data = np.clip(audio_data, -32768, 32767).astype(np.int16) + return audio_data, sample_rate + + +# TOCHECK +def play_audio_chunk(audio_data: np.ndarray, sample_rate: int) -> None: + """ + 使用PyAudio播放音频数组 + """ + global _pyaudio_instance, _pyaudio_stream, _current_sample_rate + + if _pyaudio_instance is None or _current_sample_rate != sample_rate: + if _pyaudio_stream is not None: + try: + _pyaudio_stream.stop_stream() + _pyaudio_stream.close() + except Exception: + pass + if _pyaudio_instance is not None: + try: + _pyaudio_instance.terminate() + except Exception: + pass + + frames_per_buffer = max(int(sample_rate * 0.02), 256) + _pyaudio_instance = pyaudio.PyAudio() + _pyaudio_stream = _pyaudio_instance.open( + format=pyaudio.paInt16, + channels=1, + rate=sample_rate, + output=True, + frames_per_buffer=frames_per_buffer + ) + _current_sample_rate = sample_rate + + if audio_data.dtype != np.int16: + max_val = np.abs(audio_data).max() + if max_val > 0: + audio_data = (audio_data 
/ max_val * 32767).astype(np.int16) + else: + audio_data = np.zeros_like(audio_data, dtype=np.int16) + + chunk_size = 4096 + audio_bytes = audio_data.tobytes() + for i in range(0, len(audio_bytes), chunk_size): + chunk = audio_bytes[i:i + chunk_size] + if _pyaudio_stream is not None: + _pyaudio_stream.write(chunk) + +# TOCHECK +def cleanup_audio() -> None: + """ + 释放PyAudio资源 + """ + global _pyaudio_instance, _pyaudio_stream, _current_sample_rate + if _pyaudio_stream is not None: + try: + _pyaudio_stream.stop_stream() + _pyaudio_stream.close() + except Exception: + pass + _pyaudio_stream = None + if _pyaudio_instance is not None: + try: + _pyaudio_instance.terminate() + except Exception: + pass + _pyaudio_instance = None + _current_sample_rate = None + +# TOCHECK +def play_audio_sync(audio_data: bytes) -> None: + if not audio_data: + return + audio_array, sample_rate = parse_wav_chunk(audio_data) + play_audio_chunk(audio_array, sample_rate) + +# TOCHECK +async def audio_player_worker(): + """ + 音频播放后台任务,确保音频按顺序播放 + """ + while True: + audio_data = await audio_play_queue.get() + if audio_data is None: + audio_play_queue.task_done() + break + try: + await asyncio.to_thread(play_audio_sync, audio_data) + await asyncio.to_thread(save_vocal_data, audio_data) + except Exception as exc: + PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"音频播放失败: {exc}") + finally: + audio_play_queue.task_done() + +# CHANGE TOCHECK +async def play_vocal(text:str) -> None: + if len(text) == 0 or not text: + return + tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft" + # 准备请求数据 + header = { + "accept": "application/json" + } + data = { + 'text': text, + 'speaker_id': TTS_SPEAKER_ID, + 'stream': STREAM_ENABLE, + } + # 发送POST请求 + if STREAM_ENABLE: + response = requests.post(tts_server_url, data=data, stream=True, timeout=600, headers=header) + if response.status_code != 200: + raise Exception(f"语音合成失败: {response.status_code} - {response.text}") + wav_buffer = bytearray() + for chunk in 
response.iter_content(chunk_size=1024 * 256): + if not chunk: + continue + wav_buffer.extend(chunk) + while len(wav_buffer) > 12: + if wav_buffer[:4] != b'RIFF': + riff_pos = wav_buffer.find(b'RIFF', 1) + if riff_pos == -1: + wav_buffer.clear() + break + wav_buffer = wav_buffer[riff_pos:] + if len(wav_buffer) < 8: + break + file_size = int.from_bytes(wav_buffer[4:8], byteorder='little') + expected_size = file_size + 8 + if len(wav_buffer) < expected_size: + break + complete_wav = bytes(wav_buffer[:expected_size]) + del wav_buffer[:expected_size] + await audio_play_queue.put(complete_wav) + if wav_buffer: + leftover = bytes(wav_buffer) + try: + parse_wav_chunk(leftover) + await audio_play_queue.put(leftover) + except Exception: + PrintColorful(ConsoleFrontColor.LIGHTRED_EX, "剩余音频数据解析失败,已丢弃") + else: + response = requests.post(tts_server_url, data=data, timeout=600, headers=header) + if response.status_code == 200: + await audio_play_queue.put(response.content) + else: + raise Exception(f"语音合成失败: {response.status_code} - {response.text}") + +async def ainput(wait_seconds:float) -> str: + loop = asyncio.get_event_loop() + + def get_input(): + try: + return input("\n你: ") + except EOFError: + return "" + + input_task = loop.run_in_executor(None, get_input) + while wait_seconds > 0: + if input_task.done(): + return input_task.result() + await asyncio.sleep(0.5) + wait_seconds -= 0.5 + return "" + + +async def achat(engine:SimpleChatEngine,message:str) -> None: + user_message = message if message not in [None, ""] else "(没有人说话, 请延续发言或是寻找新的话题)" + streaming_response: StreamingAgentChatResponse = await engine.astream_chat(user_message) + buffer_response = "" + + end_symbol = ['。', '?', '!'] + + # 实时输出流式文本 + async for chunk in streaming_response.async_response_gen(): + await asyncio.sleep(0.01) + print(chunk, end='', flush=True) + for ch in chunk: + buffer_response += ch + if len(buffer_response) > 20: + if ch in end_symbol: + await play_vocal(buffer_response.strip()) + 
buffer_response = "" + buffer_response = buffer_response.strip() + if len(buffer_response) > 0: + await play_vocal(buffer_response) + + +def add_speaker() -> None: + url = f"{TTS_SERVER_URL}/api/speakers/add" + headers = { + "accept": "application/json" + } + data = { + "speaker_id": TTS_SPEAKER_ID, + "prompt_text": TTS_PROMPT_TEXT, + "force_regenerate": True + } + with open(TTS_PROMPT_WAV_PATH, 'rb') as f: + extension = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower() + files = { + 'prompt_wav': (f'prompt.{extension}', f, f'audio/{extension}') + } + response = requests.post(url, data=data, files=files, headers=headers, timeout=600) + if response.status_code == 200: + PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"音色可用: {response.text}") + else: + raise Exception(f"添加音色失败: {response.status_code} - {response.text}") + + +async def event_loop(engine:SimpleChatEngine) -> None: + add_speaker() + audio_player_task = asyncio.create_task(audio_player_worker()) + message = input("请开始对话: ") + wait_second = AUTO_SPEAK_WAIT_SECOND + try: + while message != "quit" and message != "exit": + PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='') + await achat(engine, message) + PrintColorful(ConsoleFrontColor.RESET,"") + message = await ainput(wait_second) + if not message: + wait_second = max(wait_second*1.5, 3600) + else: + wait_second = AUTO_SPEAK_WAIT_SECOND + finally: + await audio_play_queue.join() + await audio_play_queue.put(None) + await audio_player_task + cleanup_audio() + + +async def main(): + # Initialize + try: + ollama_llm = Ollama(**ollama_llm_config) + Settings.llm = ollama_llm + chat_engine = SimpleChatEngine.from_defaults(**chat_engine_config) + await event_loop(chat_engine) + except Exception as e: + config.Log("Error", f"Error: {e}") + return + finally: + cleanup_audio() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/web_server.py b/web_server.py new file mode 100644 index 0000000..003b364 --- /dev/null +++ 
b/web_server.py
@@ -0,0 +1,1004 @@
+from Convention.Runtime.GlobalConfig import *
+import asyncio
+from datetime import datetime
+from llama_index.llms.ollama import Ollama
+from llama_index.core.chat_engine import SimpleChatEngine
+from llama_index.core.chat_engine.types import StreamingAgentChatResponse
+from llama_index.core import Settings
+from Convention.Runtime.File import ToolFile
+import requests
+import wave
+import io
+import base64
+import json
+import time
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
+from fastapi.responses import HTMLResponse, FileResponse
+from fastapi.staticfiles import StaticFiles
+from contextlib import asynccontextmanager
+from typing import Optional, Set
+from pydantic import BaseModel
+
+chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+# Global runtime state
+config: Optional[ProjectConfig] = None
+chat_engine: Optional[SimpleChatEngine] = None
+connected_clients: Set[WebSocket] = set()
+temp_dir: Optional[ToolFile] = None
+last_message_time: float = 0.0
+auto_speak_task: Optional[asyncio.Task] = None
+is_processing: bool = False  # True while a chat request is being processed
+current_wait_second: float = 15.0  # Current auto-speak interval (adjusted dynamically)
+last_user_message_time: float = 0.0  # Time of the last user message (used to detect replies)
+# Prompt fed to the model when nobody has spoken; also marks a turn as auto-speak.
+AUTO_SPEAK_PROMPT: str = "(没有人说话, 请延续发言或是寻找新的话题)"
+# Strong references to fire-and-forget TTS tasks; the event loop only keeps weak ones.
+_background_tasks: Set[asyncio.Task] = set()
+
+# Configuration values (defaults; overwritten by initialize_config)
+OLLAMA_URL: str = "http://localhost:11434"
+OLLAMA_MODEL: str = "gemma3:4b"
+RESPONSE_TIMEOUT: int = 60
+TEMPERATURE: float = 1.3
+MAX_CONTENT_LENGTH: Optional[int] = None
+SYSTEM_PROMPT_PATH: Optional[str] = None
+TTS_SERVER_URL: str = "http://localhost:43400"
+TTS_PROMPT_TEXT: Optional[str] = None
+TTS_PROMPT_WAV_PATH: Optional[str] = None
+TTS_SPEAKER_ID: str = "tts_speaker"
+STREAM_ENABLE: bool = False
+VERBOSE: bool = False
+AUTO_SPEAK_WAIT_SECOND: float = 15.0
+
+def initialize_config():
+    """Load every setting from ProjectConfig into the module globals and save the config."""
+    global config, temp_dir, OLLAMA_URL, OLLAMA_MODEL, RESPONSE_TIMEOUT, TEMPERATURE
+    global MAX_CONTENT_LENGTH, SYSTEM_PROMPT_PATH, TTS_SERVER_URL, TTS_PROMPT_TEXT
+    global TTS_PROMPT_WAV_PATH, TTS_SPEAKER_ID, STREAM_ENABLE, VERBOSE, AUTO_SPEAK_WAIT_SECOND
+
+    config = ProjectConfig()
+    OLLAMA_URL = config.FindItem("ollama_url", "http://localhost:11434")
+    OLLAMA_MODEL = config.FindItem("ollama_model", "gemma3:4b")
+    RESPONSE_TIMEOUT = config.FindItem("response_timeout", 60)
+    TEMPERATURE = config.FindItem("temperature", 1.3)
+    MAX_CONTENT_LENGTH = config.FindItem("max_content_length", None)
+    SYSTEM_PROMPT_PATH = config.FindItem("system_prompt_path", None)
+    TTS_SERVER_URL = config.FindItem("tts_server_url", "http://localhost:43400")
+    TTS_PROMPT_TEXT = config.FindItem("tts_prompt_text", None)
+    TTS_PROMPT_WAV_PATH = config.FindItem("tts_prompt_wav_path", None)
+    TTS_SPEAKER_ID = config.FindItem("tts_speaker_id", "tts_speaker")
+    STREAM_ENABLE = config.FindItem("stream_enable", False)
+    VERBOSE = config.FindItem("verbose", False)
+    AUTO_SPEAK_WAIT_SECOND = config.FindItem("auto_speak_wait_second", 15.0)
+
+    temp_dir = config.GetFile("temp") | chat_start_id | None
+
+    if VERBOSE:
+        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"OLLAMA_URL: {OLLAMA_URL}")
+        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"OLLAMA_MODEL: {OLLAMA_MODEL}")
+        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}")
+        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"TEMPERATURE: {TEMPERATURE}")
+        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}")
+
+    config.SaveProperties()
+
+def initialize_chat_engine():
+    """Create the Ollama LLM from the loaded settings and build the global chat engine."""
+    global chat_engine
+
+    ollama_llm_config = {
+        "model": OLLAMA_MODEL,
+        "base_url": OLLAMA_URL,
+        "request_timeout": RESPONSE_TIMEOUT,
+        "temperature": TEMPERATURE,
+    }
+
+    chat_engine_config = {}
+
+    if MAX_CONTENT_LENGTH is not None:
+        # NOTE(review): llama-index's Ollama documents "context_window" for the
+        # context size; confirm "max_content_length" is actually honoured.
+        ollama_llm_config["max_content_length"] = MAX_CONTENT_LENGTH
+
+    if SYSTEM_PROMPT_PATH is not None:
+        system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
+        chat_engine_config["system_prompt"] = system_prompt
+        if VERBOSE:
+            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"system_prompt loaded")
+
+    ollama_llm = Ollama(**ollama_llm_config)
+    Settings.llm = ollama_llm
+    chat_engine = SimpleChatEngine.from_defaults(**chat_engine_config)
+
+    if VERBOSE:
+        PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "Chat engine initialized")
+
+def save_vocal_data(data: bytes) -> ToolFile:
+    """Write ``data`` to a timestamp-named ``.wav`` file under ``temp_dir`` and return that file."""
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+    filename = f"{timestamp}.wav"
+    file = temp_dir | filename
+    file.MustExistsPath()
+    file.SaveAsBinary(data)
+    return file
+
+async def generate_tts_audio(text: str) -> Optional[bytes]:
+    """
+    Synthesise ``text`` on the TTS server and return the audio bytes.
+
+    Returns ``None`` for empty text, when no prompt WAV is configured, or when
+    the request fails (failures are printed only in verbose mode). Successful
+    audio is also written to disk via ``save_vocal_data``.
+    """
+    if not text:
+        return None
+
+    if TTS_PROMPT_WAV_PATH is None:
+        return None
+
+    tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft"
+    header = {
+        "accept": "application/json"
+    }
+    data = {
+        'text': text,
+        'speaker_id': TTS_SPEAKER_ID,
+        'stream': STREAM_ENABLE,
+    }
+
+    def _generate_sync():
+        """Blocking TTS request; returns the audio bytes or ``None`` on failure."""
+        try:
+            if STREAM_ENABLE:
+                response = requests.post(tts_server_url, data=data, stream=True, timeout=600, headers=header)
+                if response.status_code != 200:
+                    if VERBOSE:
+                        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"TTS失败: {response.status_code} - {response.text}")
+                    return None
+
+                wav_buffer = bytearray()
+                for chunk in response.iter_content(chunk_size=1024 * 256):
+                    if not chunk:
+                        continue
+                    wav_buffer.extend(chunk)
+
+                if wav_buffer:
+                    complete_wav = bytes(wav_buffer)
+                    save_vocal_data(complete_wav)
+                    return complete_wav
+                return None
+            else:
+                response = requests.post(tts_server_url, data=data, timeout=600, headers=header)
+                if response.status_code == 200:
+                    audio_data = response.content
+                    save_vocal_data(audio_data)
+                    return audio_data
+                else:
+                    if VERBOSE:
+                        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"TTS失败: {response.status_code} - {response.text}")
+                    return None
+        except Exception as e:
+            if VERBOSE:
+                PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"TTS异常: {e}")
+            return None
+
+    # Run the blocking request in a worker thread
+    return await asyncio.to_thread(_generate_sync)
+
+def add_speaker() -> None:
+    """Register the configured reference voice with the TTS server; failures are printed, not raised."""
+    if TTS_PROMPT_WAV_PATH is None or TTS_PROMPT_TEXT is None:
+        if VERBOSE:
+            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "TTS音色配置不完整,跳过初始化")
+        return
+
+    url = f"{TTS_SERVER_URL}/api/speakers/add"
+    headers = {
+        "accept": "application/json"
+    }
+    data = {
+        "speaker_id": TTS_SPEAKER_ID,
+        "prompt_text": TTS_PROMPT_TEXT,
+        "force_regenerate": True
+    }
+
+    try:
+        with open(TTS_PROMPT_WAV_PATH, 'rb') as f:
+            extension = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower()
+            files = {
+                'prompt_wav': (f'prompt.{extension}', f, f'audio/{extension}')
+            }
+            response = requests.post(url, data=data, files=files, headers=headers, timeout=600)
+            if response.status_code == 200:
+                PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"音色可用: {response.text}")
+            else:
+                PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"添加音色失败: {response.status_code} - {response.text}")
+    except Exception as e:
+        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"添加音色异常: {e}")
+
+async def auto_speak_task_func():
+    """
+    Background task that makes the model speak when users stay silent.
+
+    Sleeps for ``current_wait_second``; if clients are connected, nothing is
+    being processed and no user message arrived within that interval, it
+    sends the auto-speak prompt to every client. Without a user reply the
+    interval grows by 1.5x, capped at 3600 seconds.
+    """
+    global chat_engine, last_message_time, config, connected_clients, is_processing, current_wait_second, last_user_message_time
+
+    # Initialise the last-message timestamps and the wait interval
+    last_message_time = time.time()
+    last_user_message_time = time.time()
+    current_wait_second = AUTO_SPEAK_WAIT_SECOND
+
+    while True:
+        # Sleep for the current (dynamic) interval
+        await asyncio.sleep(current_wait_second)
+
+        if chat_engine is None or config is None:
+            continue
+
+        current_time = time.time()
+        # Require a connected client, an elapsed interval and no request in progress.
+        # Only user messages count for the interval (last_user_message_time), not auto-speak turns.
+        if (connected_clients and
+            not is_processing and
+            (current_time - last_user_message_time) >= current_wait_second):
+            try:
+                # Trigger an auto-speak turn
+                auto_message = AUTO_SPEAK_PROMPT
+
+                # Send the auto-speak turn to every connected client
+                for websocket in list(connected_clients):
+                    try:
+                        await handle_chat_stream(websocket, auto_message)
+                    except Exception as e:
+                        if VERBOSE:
+                            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"自动发言发送错误: {e}")
+
+                # Give the user a chance to reply: 10% of the current interval, at least 1 second
+                check_wait_time = max(current_wait_second * 0.1, 1.0)
+                await asyncio.sleep(check_wait_time)
+
+                # A reply would have refreshed last_user_message_time during that pause
+                time_after_check = time.time()
+                if (time_after_check - last_user_message_time) > check_wait_time + 2.0:
+                    # No reply: back off gradually
+                    current_wait_second = min(current_wait_second * 1.5, 3600.0)
+                    if VERBOSE:
+                        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"无用户回应,等待间隔调整为: {current_wait_second}秒")
+                # If the user did reply, handle_chat_stream has already reset the interval
+
+            except Exception as e:
+                if VERBOSE:
+                    PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"自动发言任务错误: {e}")
+
+# FastAPI application
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan: initialise on startup, stop the auto-speak task and save config on shutdown."""
+    global auto_speak_task, last_message_time, last_user_message_time
+
+    # Startup initialisation
+    initialize_config()
+    initialize_chat_engine()
+    add_speaker()
+    last_message_time = time.time()
+    last_user_message_time = time.time()
+
+    # Start the auto-speak background task
+    auto_speak_task = asyncio.create_task(auto_speak_task_func())
+
+    yield
+
+    # Shutdown cleanup
+    if auto_speak_task and not auto_speak_task.done():
+        auto_speak_task.cancel()
+        try:
+            await auto_speak_task
+        except asyncio.CancelledError:
+            pass
+
+    if config:
+        config.SaveProperties()
+
+app = FastAPI(lifespan=lifespan)
+
+# WebSocket connection management
+async def connect_client(websocket: WebSocket):
+    """Accept the WebSocket and add it to ``connected_clients``."""
+    await websocket.accept()
+    connected_clients.add(websocket)
+    if VERBOSE:
+        PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, f"客户端已连接,当前连接数: {len(connected_clients)}")
+
+async def disconnect_client(websocket: WebSocket):
+    """Remove the WebSocket from ``connected_clients`` (no error if it is absent)."""
+    connected_clients.discard(websocket)
+    if VERBOSE:
+        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"客户端已断开,当前连接数: {len(connected_clients)}")
+
+# API models
+class ChatRequest(BaseModel):
+    message: str
+
+async def safe_send_json(websocket: WebSocket, data: dict) -> bool:
+    """Send ``data`` as JSON if the client is still registered; return whether the send succeeded."""
+    if websocket not in connected_clients:
+        return False
+    try:
+        await websocket.send_json(data)
+        return True
+    except (WebSocketDisconnect, RuntimeError, ConnectionError):
+        if websocket in connected_clients:
+            await disconnect_client(websocket)
+        return False
+    except Exception as e:
+        if VERBOSE:
+            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"发送消息错误: {e}")
+        return False
+
+def _spawn_background(coro) -> asyncio.Task:
+    """Schedule ``coro`` as a task and hold a reference to it until it finishes."""
+    task = asyncio.create_task(coro)
+    _background_tasks.add(task)
+    task.add_done_callback(_background_tasks.discard)
+    return task
+
+async def generate_and_send_audio(websocket: WebSocket, text: str):
+    """Synthesise ``text`` and send it to the client as a base64 ``audio`` message."""
+    try:
+        # Skip the work if the WebSocket is already gone
+        if websocket not in connected_clients:
+            return
+
+        audio_data = await generate_tts_audio(text)
+
+        # Check again: the client may have disconnected during synthesis
+        if websocket not in connected_clients:
+            return
+
+        if audio_data:
+            audio_base64 = base64.b64encode(audio_data).decode('utf-8')
+            await safe_send_json(websocket, {
+                "type": "audio",
+                "audio": audio_base64
+            })
+    except (WebSocketDisconnect, RuntimeError, ConnectionError) as e:
+        # Connection already closed; nothing to report
+        if websocket in connected_clients:
+            await disconnect_client(websocket)
+    except Exception as e:
+        if VERBOSE:
+            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"音频生成错误: {e}")
+
+# WebSocket endpoint
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    """
+    WebSocket endpoint for streaming chat.
+
+    Handles JSON text frames ``{"type": "chat", "message": ...}`` and
+    ``{"type": "ping"}``; frames with any other type are ignored.
+    """
+    await connect_client(websocket)
+    try:
+        while True:
+            # Receive the next client frame
+            data = await websocket.receive_text()
+            try:
+                message_data = json.loads(data)
+            except json.JSONDecodeError:
+                message_data = None
+            if not isinstance(message_data, dict):
+                # Answer malformed frames instead of dropping the connection.
+                await safe_send_json(websocket, {
+                    "type": "error",
+                    "message": "消息格式错误"
+                })
+                continue
+
+            if message_data.get("type") == "chat":
+                message = message_data.get("message", "")
+                if not isinstance(message, str) or not message:
+                    await safe_send_json(websocket, {
+                        "type": "error",
+                        "message": "消息不能为空"
+                    })
+                    continue
+
+                # Handle the chat request
+                await handle_chat_stream(websocket, message)
+            elif message_data.get("type") == "ping":
+                await safe_send_json(websocket, {"type": "pong"})
+
+    except WebSocketDisconnect:
+        await disconnect_client(websocket)
+    except Exception as e:
+        if VERBOSE:
+            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"WebSocket错误: {e}")
+        await disconnect_client(websocket)
+
+async def handle_chat_stream(websocket: WebSocket, message: str):
+    """
+    Stream one chat reply to ``websocket`` and trigger TTS per sentence.
+
+    Sends ``start``, ``chunk`` and ``complete`` messages (or ``error``). A
+    message different from ``AUTO_SPEAK_PROMPT`` counts as user activity and
+    resets the auto-speak interval.
+    """
+    global chat_engine, last_message_time, is_processing, current_wait_second, last_user_message_time
+
+    if chat_engine is None:
+        await safe_send_json(websocket, {
+            "type": "error",
+            "message": "聊天引擎未初始化"
+        })
+        return
+
+    # Record the message time and mark a request as in progress
+    last_message_time = time.time()
+    is_processing = True
+
+    # A real user message (not an auto-speak turn) refreshes the user timestamp and resets the interval
+    if message != AUTO_SPEAK_PROMPT:
+        last_user_message_time = time.time()
+        current_wait_second = AUTO_SPEAK_WAIT_SECOND
+        if VERBOSE:
+            PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, f"用户发送消息,等待间隔重置为: {current_wait_second}秒")
+
+    user_message = message if message not in [None, ""] else AUTO_SPEAK_PROMPT
+
+    try:
+        streaming_response: StreamingAgentChatResponse = await chat_engine.astream_chat(user_message)
+        buffer_response = ""
+        end_symbol = ['。', '?', '!']
+        audio_text_buffer = ""
+
+        # Announce the start of a reply
+        if not await safe_send_json(websocket, {
+            "type": "start",
+            "message": ""
+        }):
+            return
+
+        # Stream the reply
+        async for chunk in streaming_response.async_response_gen():
+            # Stop if the client went away
+            if websocket not in connected_clients:
+                break
+
+            await asyncio.sleep(0.01)
+            buffer_response += chunk
+
+            # Forward the text chunk
+            if not await safe_send_json(websocket, {
+                "type": "chunk",
+                "message": chunk
+            }):
+                break
+
+            # Decide whether a sentence is ready for TTS
+            for ch in chunk:
+                audio_text_buffer += ch
+                if len(audio_text_buffer) > 20:
+                    if ch in end_symbol:
+                        text_to_speak = audio_text_buffer.strip()
+                        if text_to_speak:
+                            # Synthesise in the background so text streaming is not blocked.
+                            # NOTE(review): segments are synthesised concurrently, so audio
+                            # messages may reach the client out of order - confirm the client copes.
+                            _spawn_background(generate_and_send_audio(websocket, text_to_speak))
+                        audio_text_buffer = ""
+
+        # Stop if the client went away
+        if websocket not in connected_clients:
+            return
+
+        # Flush what is left
+        if buffer_response.strip():
+            # Send the completion message
+            await safe_send_json(websocket, {
+                "type": "complete",
+                "message": buffer_response.strip()
+            })
+
+            # Synthesise the remaining text
+            if audio_text_buffer.strip():
+                _spawn_background(generate_and_send_audio(websocket, audio_text_buffer.strip()))
+        else:
+            await safe_send_json(websocket, {
+                "type": "complete",
+                "message": ""
+            })
+
+    except (WebSocketDisconnect, RuntimeError, ConnectionError):
+        if websocket in connected_clients:
+            await disconnect_client(websocket)
+    except Exception as e:
+        if VERBOSE:
+            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"聊天处理错误: {e}")
+        await safe_send_json(websocket, {
+            "type": "error",
+            "message": f"处理请求时发生错误: {str(e)}"
+        })
+    finally:
+        # Clear the in-progress flag
+        is_processing = False
+
+# 静态文件服务
+@app.get("/", response_class=HTMLResponse)
+async def read_root():
+    """返回前端页面"""
+    return HTMLResponse(content=get_html_content())
+
+def get_html_content() -> str:
+    """生成HTML内容"""
+    return """
+
+
+ + +