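"""
Multimodal chat client for a local Ollama model.

Streams replies token-by-token, optionally attaching an image to each turn
(an ``@/path/to/image.png`` reference in the user input, or an automatic
screenshot when none is given), and can speak each completed sentence through
an external TTS server. If the user stays silent past a timeout, the AI takes
a screenshot and continues the conversation on its own. All settings are read
from (and persisted back to) the project's ProjectConfig store.
"""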
from Convention.Runtime.GlobalConfig import *
import asyncio
import json
from datetime import datetime
from llama_index.llms.ollama import Ollama
from llama_index.core.llms import ChatMessage
# try:
#     from llama_index.core.llms.types import ImageBlock, TextBlock
# except ImportError:
#     try:
#         # Try other possible import paths
#         from llama_index.core import ImageBlock, TextBlock
#     except ImportError:
#         # If both fail, define simple placeholder classes
#         class ImageBlock:
#             def __init__(self, base64_str=None, path=None):
#                 self.base64_str = base64_str
#                 self.path = path
#
#         class TextBlock:
#             def __init__(self, text=""):
#                 self.text = text
from llama_index.core import Settings
from Convention.Runtime.File import ToolFile
import requests
import wave
import io
import numpy as np
import pyaudio
import base64
import os
import re
from typing import Optional
from PIL import ImageGrab
from io import BytesIO

chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")


config = ProjectConfig()
OLLAMA_URL = config.FindItem("ollama_url", "http://localhost:11434")
OLLAMA_MODEL = config.FindItem("ollama_model", "gemma3:4b")
RESPONSE_TIMEOUT = config.FindItem("response_timeout", 60)
TEMPERATURE = config.FindItem("temperature", 1.3)
MAX_CONTENT_LENGTH = config.FindItem("max_content_length", None)
SYSTEM_PROMPT_PATH = config.FindItem("system_prompt_path", None)
AUTO_SPEAK_WAIT_SECOND = config.FindItem("auto_speak_wait_second", 15.0)
TTS_ENABLE = config.FindItem("tts_enable", False)
TTS_SERVER_URL = config.FindItem("tts_server_url", "http://localhost:43400")
TTS_PROMPT_TEXT = config.FindItem("tts_prompt_text", None)
TTS_PROMPT_WAV_PATH = config.FindItem("tts_prompt_wav_path", None)
TTS_SPEAKER_ID = config.FindItem("tts_speaker_id", "tts_speaker")
SEED = config.FindItem("seed", 0)
STREAM_ENABLE = config.FindItem("stream_enable", False)
VERBOSE = config.FindItem("verbose", False)

if VERBOSE:
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"OLLAMA_URL: {OLLAMA_URL}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"OLLAMA_MODEL: {OLLAMA_MODEL}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"TEMPERATURE: {TEMPERATURE}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"SYSTEM_PROMPT_PATH: {SYSTEM_PROMPT_PATH}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"TTS_ENABLE: {TTS_ENABLE}")

# Per-session temp directory for synthesized audio (keyed by chat start time).
temp_dir = config.GetFile("temp") | chat_start_id | None

ollama_llm_config = {
    "model": OLLAMA_MODEL,
    "base_url": OLLAMA_URL,
    "request_timeout": RESPONSE_TIMEOUT,
    "temperature": TEMPERATURE,
}
chat_engine_config = {}

if MAX_CONTENT_LENGTH is not None:
    ollama_llm_config["max_content_length"] = MAX_CONTENT_LENGTH
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")

if SYSTEM_PROMPT_PATH is not None:
    system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
    chat_engine_config["system_prompt"] = system_prompt
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"system_prompt: {system_prompt}")

config.SaveProperties()

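# NOTE (assumption about the Convention library): FindItem presumably falls
# back to, and records, the given default, so SaveProperties() above should
# leave a fully populated config file after the first run.
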
def save_vocal_data(data: bytes) -> ToolFile:
    """Persist one synthesized WAV chunk under the per-session temp directory."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"{timestamp}.wav"
    file = temp_dir | filename
    file.MustExistsPath()
    file.SaveAsBinary(data)
    return file

# `None` is used as a shutdown sentinel, hence Queue[bytes | None].
audio_play_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
_pyaudio_instance: pyaudio.PyAudio | None = None
_pyaudio_stream: pyaudio.Stream | None = None
_current_sample_rate: int | None = None

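# Playback is a single-producer/single-consumer pipeline: play_vocal() enqueues
# complete WAV files, and audio_player_worker() (started in event_loop when TTS
# is enabled) dequeues and plays them one at a time, so sentences never overlap.
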
# TOCHECK
def parse_wav_chunk(wav_bytes: bytes) -> tuple[np.ndarray, int]:
    """
    Parse WAV data; return the mono int16 audio array and its sample rate.
    """
    wav_file = wave.open(io.BytesIO(wav_bytes))
    sample_rate = wav_file.getframerate()
    n_channels = wav_file.getnchannels()
    sample_width = wav_file.getsampwidth()
    n_frames = wav_file.getnframes()
    audio_bytes = wav_file.readframes(n_frames)
    wav_file.close()

    if sample_width == 2:
        audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
    elif sample_width == 4:
        # Normalize 32-bit samples to full-scale 16-bit.
        audio_data_32 = np.frombuffer(audio_bytes, dtype=np.int32)
        max_val = np.abs(audio_data_32).max()
        if max_val > 0:
            audio_data = (audio_data_32 / max_val * 32767).astype(np.int16)
        else:
            audio_data = np.zeros(len(audio_data_32), dtype=np.int16)
    else:
        raise ValueError(f"Unsupported sample width: {sample_width}")

    # Downmix stereo to mono by averaging the two channels.
    if n_channels == 2:
        audio_data = audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)

    audio_data = np.clip(audio_data, -32768, 32767).astype(np.int16)
    return audio_data, sample_rate

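# A quick in-memory self-check for parse_wav_chunk (a minimal sketch; builds
# one second of 16 kHz mono silence rather than assuming any file exists):
#
#   _buf = io.BytesIO()
#   with wave.open(_buf, 'wb') as _w:
#       _w.setnchannels(1); _w.setsampwidth(2); _w.setframerate(16000)
#       _w.writeframes(b'\x00\x00' * 16000)
#   _audio, _sr = parse_wav_chunk(_buf.getvalue())
#   assert _sr == 16000 and _audio.dtype == np.int16 and len(_audio) == 16000
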
# TOCHECK
def play_audio_chunk(audio_data: np.ndarray, sample_rate: int) -> None:
    """
    Play an audio array through PyAudio, reopening the output stream whenever
    the sample rate changes.
    """
    global _pyaudio_instance, _pyaudio_stream, _current_sample_rate

    if _pyaudio_instance is None or _current_sample_rate != sample_rate:
        if _pyaudio_stream is not None:
            try:
                _pyaudio_stream.stop_stream()
                _pyaudio_stream.close()
            except Exception:
                pass
        if _pyaudio_instance is not None:
            try:
                _pyaudio_instance.terminate()
            except Exception:
                pass

        # ~20 ms buffers, but never fewer than 256 frames.
        frames_per_buffer = max(int(sample_rate * 0.02), 256)
        _pyaudio_instance = pyaudio.PyAudio()
        _pyaudio_stream = _pyaudio_instance.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=sample_rate,
            output=True,
            frames_per_buffer=frames_per_buffer
        )
        _current_sample_rate = sample_rate

    if audio_data.dtype != np.int16:
        max_val = np.abs(audio_data).max()
        if max_val > 0:
            audio_data = (audio_data / max_val * 32767).astype(np.int16)
        else:
            audio_data = np.zeros_like(audio_data, dtype=np.int16)

    # Feed the stream in 4 KiB slices so long clips don't block in one write.
    chunk_size = 4096
    audio_bytes = audio_data.tobytes()
    for i in range(0, len(audio_bytes), chunk_size):
        chunk = audio_bytes[i:i + chunk_size]
        if _pyaudio_stream is not None:
            _pyaudio_stream.write(chunk)

# TOCHECK
def cleanup_audio() -> None:
    """
    Release all PyAudio resources.
    """
    global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
    if _pyaudio_stream is not None:
        try:
            _pyaudio_stream.stop_stream()
            _pyaudio_stream.close()
        except Exception:
            pass
        _pyaudio_stream = None
    if _pyaudio_instance is not None:
        try:
            _pyaudio_instance.terminate()
        except Exception:
            pass
        _pyaudio_instance = None
    _current_sample_rate = None

# TOCHECK
def play_audio_sync(audio_data: bytes) -> None:
    """Decode one complete WAV file and play it (blocking)."""
    if not audio_data:
        return
    audio_array, sample_rate = parse_wav_chunk(audio_data)
    play_audio_chunk(audio_array, sample_rate)

# TOCHECK
async def audio_player_worker():
    """
    Background task that plays queued audio strictly in order; a `None` item
    shuts the worker down.
    """
    while True:
        audio_data = await audio_play_queue.get()
        if audio_data is None:
            audio_play_queue.task_done()
            break
        try:
            # Playback and disk I/O are blocking, so run them off the loop.
            await asyncio.to_thread(play_audio_sync, audio_data)
            await asyncio.to_thread(save_vocal_data, audio_data)
        except Exception as exc:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Audio playback failed: {exc}")
        finally:
            audio_play_queue.task_done()

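# WAV/RIFF framing used by the streaming branch of play_vocal below: every
# WAV file starts with b'RIFF', then a uint32 little-endian size equal to the
# total file size minus 8, then b'WAVE'. The reassembly loop therefore scans
# for 'RIFF', reads the size field, waits until (size + 8) bytes are buffered,
# and emits exactly one complete playable file per iteration.
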
# CHANGE TOCHECK
async def play_vocal(text: str) -> None:
    """Send text to the TTS server and enqueue the returned WAV(s) for playback."""
    if not TTS_ENABLE:
        return
    if not text:
        return
    tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft"
    # Prepare the request payload
    header = {
        "accept": "application/json"
    }
    data = {
        'text': text,
        'speaker_id': TTS_SPEAKER_ID,
        'stream': STREAM_ENABLE,
    }
    # Send the POST request (NOTE: requests is synchronous, so these calls
    # block the event loop while the TTS server responds)
    if STREAM_ENABLE:
        response = requests.post(tts_server_url, data=data, stream=True, timeout=600, headers=header)
        if response.status_code != 200:
            raise Exception(f"Speech synthesis failed: {response.status_code} - {response.text}")
        # The server streams back a concatenation of complete WAV files;
        # reassemble them one RIFF chunk at a time (see the note above).
        wav_buffer = bytearray()
        for chunk in response.iter_content(chunk_size=1024 * 256):
            if not chunk:
                continue
            wav_buffer.extend(chunk)
            while len(wav_buffer) > 12:
                if wav_buffer[:4] != b'RIFF':
                    # Resynchronize on the next RIFF magic; drop garbage bytes.
                    riff_pos = wav_buffer.find(b'RIFF', 1)
                    if riff_pos == -1:
                        wav_buffer.clear()
                        break
                    wav_buffer = wav_buffer[riff_pos:]
                if len(wav_buffer) < 8:
                    break
                file_size = int.from_bytes(wav_buffer[4:8], byteorder='little')
                expected_size = file_size + 8
                if len(wav_buffer) < expected_size:
                    break
                complete_wav = bytes(wav_buffer[:expected_size])
                del wav_buffer[:expected_size]
                await audio_play_queue.put(complete_wav)
        if wav_buffer:
            # Try to salvage a trailing chunk whose size field overshot the data.
            leftover = bytes(wav_buffer)
            try:
                parse_wav_chunk(leftover)
                await audio_play_queue.put(leftover)
            except Exception:
                PrintColorful(ConsoleFrontColor.LIGHTRED_EX, "Leftover audio data could not be parsed; discarded")
    else:
        response = requests.post(tts_server_url, data=data, timeout=600, headers=header)
        if response.status_code == 200:
            await audio_play_queue.put(response.content)
        else:
            raise Exception(f"Speech synthesis failed: {response.status_code} - {response.text}")

async def ainput(wait_seconds: float) -> str:
    """
    Wait up to `wait_seconds` for console input; return "" on timeout or EOF.
    input() cannot be cancelled portably, so this polls a worker-thread future.
    """
    loop = asyncio.get_running_loop()

    def get_input():
        try:
            return input("\nYou: ")
        except EOFError:
            return ""

    input_task = loop.run_in_executor(None, get_input)
    while wait_seconds > 0:
        if input_task.done():
            return input_task.result()
        await asyncio.sleep(0.5)
        wait_seconds -= 0.5
    return ""


def extract_image_paths(text: str) -> list[str]:
    """
    Extract every image path referenced with a leading @ from the text.

    Args:
        text: the user's input text

    Returns:
        List of referenced image paths that exist on disk.
    """
    # Match @-prefixed paths, optionally wrapped in quotes (note: the character
    # class excludes whitespace, so quoted paths containing spaces won't match).
    pattern = r'@["\']?([^"\'\s]+\.(?:png|jpg|jpeg))["\']?'
    matches = re.findall(pattern, text, re.IGNORECASE)
    valid_paths = []
    for match in matches:
        if os.path.isfile(match) and match.lower().endswith(('.png', '.jpg', '.jpeg')):
            valid_paths.append(match)
        elif VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"Image path does not exist or format unsupported: {match}")
    return valid_paths

def remove_image_paths_from_text(text: str, image_paths: list[str]) -> str:
    """
    Strip image path references from the text.

    Args:
        text: the original text
        image_paths: paths previously extracted from it

    Returns:
        The cleaned text.
    """
    result = text
    for path in image_paths:
        # Remove the @path reference
        result = re.sub(rf'@["\']?{re.escape(path)}["\']?', '', result, flags=re.IGNORECASE)
    # Collapse whitespace left behind by the removals
    result = re.sub(r'\s+', ' ', result).strip()
    return result

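# Example of the two helpers together (hypothetical path; extract_image_paths
# only returns paths that actually exist on disk):
#   extract_image_paths("look at @/tmp/shot.png please")
#       -> ["/tmp/shot.png"]
#   remove_image_paths_from_text("look at @/tmp/shot.png please", ["/tmp/shot.png"])
#       -> "look at please"
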
def image_file_to_base64(image_path: str) -> Optional[str]:
    """
    Convert an image file to a base64-encoded string.

    Args:
        image_path: path to the image file

    Returns:
        The base64-encoded string, or None if the file could not be read.
    """
    try:
        with open(image_path, 'rb') as image_file:
            image_data = image_file.read()
            return base64.b64encode(image_data).decode('utf-8')
    except Exception as e:
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Failed to read image file {image_path}: {e}")
        return None

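# Ollama's /api/chat endpoint streams newline-delimited JSON, one object per
# token batch, e.g. {"message": {"role": "assistant", "content": "..."},
# "done": false}, with "done": true on the final line. Multimodal models take
# base64-encoded images via an "images" array on a user message. The helper
# below talks to that API directly because the llama-index wrapper may not
# forward images.
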
async def _ollama_stream_chat_with_image(llm: Ollama, messages: list, image_base64: str, end_symbol: list[str]) -> str:
    """
    Call the Ollama API directly for streaming chat with an attached image.

    Args:
        llm: the Ollama LLM instance
        messages: the message history
        image_base64: base64-encoded image
        end_symbol: sentence-ending characters that trigger a TTS flush

    Returns:
        The complete assistant response text.
    """
    full_response = ""
    buffer_response = ""
    import aiohttp  # local import: only needed for this direct-API path

    # Build the request target
    url = f"{llm.base_url}/api/chat" if hasattr(llm, 'base_url') else f"{OLLAMA_URL}/api/chat"
    model = llm.model if hasattr(llm, 'model') else OLLAMA_MODEL

    # Convert the history into Ollama's message format
    api_messages = []
    for msg in messages:
        api_msg = {
            "role": msg.role if hasattr(msg, 'role') else "user",
            "content": msg.content if hasattr(msg, 'content') else str(msg)
        }
        api_messages.append(api_msg)

    # Attach the image to the most recent user message (the current turn)
    if image_base64:
        for i in range(len(api_messages) - 1, -1, -1):
            if api_messages[i].get("role") == "user":
                api_messages[i]["images"] = [image_base64]
                break

    payload = {
        "model": model,
        "messages": api_messages,
        "stream": True
    }

    if getattr(llm, 'temperature', None) is not None:
        payload["options"] = {"temperature": llm.temperature}

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"Ollama API error: {response.status} - {error_text}")

                # Ollama streams the response as one JSON object per line
                async for line_bytes in response.content:
                    if not line_bytes:
                        continue

                    try:
                        line_text = line_bytes.decode('utf-8').strip()
                        if not line_text:
                            continue

                        data = json.loads(line_text)
                        if 'message' in data and 'content' in data['message']:
                            chunk = data['message']['content']
                            if chunk:
                                await asyncio.sleep(0.01)
                                print(chunk, end='', flush=True)
                                full_response += chunk
                                for ch in chunk:
                                    buffer_response += ch
                                    # Flush a sentence to TTS once it is long
                                    # enough and ends on a sentence boundary
                                    if len(buffer_response) > 20 and ch in end_symbol:
                                        if TTS_ENABLE:
                                            await play_vocal(buffer_response.strip())
                                        buffer_response = ""

                        # Stop once the server reports completion
                        if data.get('done', False):
                            break
                    except json.JSONDecodeError:
                        continue
                    except Exception as e:
                        if VERBOSE:
                            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Error parsing response: {e}")

        # Speak whatever remains after the last sentence boundary
        buffer_response = buffer_response.strip()
        if buffer_response and TTS_ENABLE:
            await play_vocal(buffer_response)

        # Return the full text, not just the unflushed tail, so the caller can
        # record the complete assistant turn in the conversation history
        return full_response.strip()
    except Exception as e:
        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Ollama API call error: {e}")
        if VERBOSE:
            import traceback
            traceback.print_exc()
        raise

def capture_screenshot() -> Optional[str]:
    """
    Capture the current screen and return it as a base64-encoded PNG.

    Returns:
        The base64-encoded screenshot, or None if capturing failed (PIL's
        ImageGrab supports Windows and macOS natively; Linux support depends
        on the Pillow version and backend).
    """
    try:
        # Grab the screen with PIL ImageGrab
        screenshot = ImageGrab.grab()
        buffered = BytesIO()
        screenshot.save(buffered, format="PNG")
        image_data = buffered.getvalue()
        return base64.b64encode(image_data).decode('utf-8')
    except Exception as e:
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Screenshot failed: {e}")
        return None

# Prompt injected when the user stays silent and the AI should speak on its own.
AUTO_SPEAK_PROMPT = "(No one has spoken; continue your remarks or find a new topic)"


async def achat(llm: Ollama, message: str, image_base64: Optional[str] = None, auto_screenshot: bool = False, conversation_history: Optional[list] = None) -> list:
    """
    Run one multimodal chat turn against the Ollama LLM.

    Args:
        llm: the Ollama LLM instance
        message: the user's text message ("" triggers autonomous AI speech)
        image_base64: optional base64-encoded image (a single image)
        auto_screenshot: take a screenshot when no image was provided (used for
            autonomous AI speech)
        conversation_history: prior message list to continue from

    Returns:
        The updated conversation history.
    """
    is_auto_speak = message in (None, "")
    user_message = AUTO_SPEAK_PROMPT if is_auto_speak else message

    # Image handling: use the provided image if any; otherwise screenshot when
    # requested or when the AI is speaking on its own
    if not image_base64 and (auto_screenshot or is_auto_speak):
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "Autonomous AI turn; taking a screenshot...")
        image_base64 = capture_screenshot()
        if image_base64 and VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "Screenshot captured")

    # Make sure the base64 string carries no data-URL prefix
    if image_base64:
        if image_base64.startswith('data:image'):
            # Strip the data:image/xxx;base64, prefix
            image_base64 = image_base64.split(',', 1)[1]
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "Image attached to the message")

    # Build the message with text only; the image travels through Ollama's
    # underlying API instead
    chat_message = ChatMessage(role="user", content=user_message)

    # Use the provided conversation history, if any
    if conversation_history is None:
        conversation_history = []

    # Full message list including history
    messages = conversation_history.copy()

    # On the first turn, prepend the system prompt (only once)
    if len(messages) == 0 and SYSTEM_PROMPT_PATH is not None:
        system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
        if system_prompt:
            system_msg = ChatMessage(role="system", content=system_prompt)
            messages.append(system_msg)

    # Append the current user message
    messages.append(chat_message)

    full_response = ""
    buffer_response = ""
    end_symbol = ['。', '?', '!']  # fullwidth sentence-ending punctuation

    try:
        if image_base64:
            # With an image, call the Ollama API directly, since the
            # llama-index wrapper may not support images
            assistant_response = await _ollama_stream_chat_with_image(llm, messages, image_base64, end_symbol)
        else:
            # Plain streaming chat (no image)
            streaming_response = await llm.astream_chat(messages)

            # Print the stream as it arrives
            async for chunk in streaming_response.async_response_gen():
                await asyncio.sleep(0.01)
                print(chunk, end='', flush=True)
                full_response += chunk
                for ch in chunk:
                    buffer_response += ch
                    if len(buffer_response) > 20 and ch in end_symbol:
                        if TTS_ENABLE:
                            await play_vocal(buffer_response.strip())
                        buffer_response = ""

            # Speak whatever remains after the last sentence boundary
            leftover = buffer_response.strip()
            if leftover and TTS_ENABLE:
                await play_vocal(leftover)
            # Keep the full text so the history is not truncated to the tail
            assistant_response = full_response.strip()

        # Update the history with this turn's user and assistant messages
        updated_history = messages.copy()
        if assistant_response:
            assistant_msg = ChatMessage(role="assistant", content=assistant_response)
            updated_history.append(assistant_msg)

        return updated_history
    except Exception as e:
        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Chat error: {e}")
        if VERBOSE:
            import traceback
            traceback.print_exc()
        return conversation_history if conversation_history else []

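# Typical call pattern (sketch): thread the returned history through
# successive turns so the system prompt is only ever sent once:
#   history = await achat(llm, "hello", None, False, [])
#   history = await achat(llm, "what do you see?", None, True, history)
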
def add_speaker() -> None:
    """Register (or regenerate) the configured speaker voice on the TTS server."""
    url = f"{TTS_SERVER_URL}/api/speakers/add"
    headers = {
        "accept": "application/json"
    }
    data = {
        "speaker_id": TTS_SPEAKER_ID,
        "prompt_text": TTS_PROMPT_TEXT,
        "force_regenerate": True
    }
    with open(TTS_PROMPT_WAV_PATH, 'rb') as f:
        extension = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower()
        files = {
            'prompt_wav': (f'prompt.{extension}', f, f'audio/{extension}')
        }
        response = requests.post(url, data=data, files=files, headers=headers, timeout=600)
    if response.status_code == 200:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"Speaker voice available: {response.text}")
    else:
        raise Exception(f"Failed to add speaker voice: {response.status_code} - {response.text}")

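# Conversation flow implemented by event_loop:
#   1. read user input, waiting at most `wait_second` seconds
#   2. attach an @-referenced image if given, otherwise take a screenshot
#   3. stream the reply via achat, speaking finished sentences when TTS is on
#   4. on silence, let the AI speak autonomously and grow the wait 1.5x
#      (capped at one hour); any user input resets it
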
async def event_loop(llm: Ollama) -> None:
    """
    Main loop handling user input and AI responses.

    Args:
        llm: the Ollama LLM instance
    """
    if TTS_ENABLE:
        add_speaker()
        audio_player_task = asyncio.create_task(audio_player_worker())
    else:
        audio_player_task = None

    PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "Multimodal chat started (text and image input supported)")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "Tip: reference an image with @/path/to/image.png, otherwise the screen is captured automatically")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "Supported image formats: .png, .jpg, .jpeg")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "Type 'quit' or 'exit' to leave\n")

    # The conversation history lives here so the system prompt is sent only once
    conversation_history = []

    message = input("Start the conversation: ")
    wait_second = AUTO_SPEAK_WAIT_SECOND
    try:
        while message not in ("quit", "exit"):
            image_base64 = None

            # Look for @-prefixed image paths in the user input
            image_paths = extract_image_paths(message)

            if image_paths:
                # Use the first referenced image file
                image_path = image_paths[0]
                if VERBOSE:
                    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"Image path detected: {image_path}")
                image_base64 = image_file_to_base64(image_path)
                # Strip the path references from the message text
                message = remove_image_paths_from_text(message, image_paths)
            else:
                # No image path given; capture the screen instead
                if VERBOSE:
                    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "No image path detected; taking a screenshot...")
                image_base64 = capture_screenshot()
                if image_base64 and VERBOSE:
                    PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "Screenshot captured")

            # An empty message means this is an autonomous AI turn (achat
            # substitutes the auto-speak prompt)
            PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='')
            conversation_history = await achat(llm, message, image_base64, False, conversation_history)
            PrintColorful(ConsoleFrontColor.RESET, "")

            # Wait for the next user input; on timeout, loop around with an
            # empty message so the AI speaks on its own exactly once (the
            # original issued a second, duplicate turn here)
            message = await ainput(wait_second)
            if not message:
                wait_second = min(wait_second * 1.5, 3600)
                if VERBOSE:
                    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"No user input; AI will speak autonomously (next wait: {wait_second}s)")
            else:
                wait_second = AUTO_SPEAK_WAIT_SECOND
    finally:
        if TTS_ENABLE and audio_player_task is not None:
            # Drain pending audio, then stop the worker with the None sentinel
            await audio_play_queue.join()
            await audio_play_queue.put(None)
            await audio_player_task
        cleanup_audio()

async def main():
    # Initialize
    try:
        ollama_llm = Ollama(**ollama_llm_config)
        Settings.llm = ollama_llm

        # If a system prompt is configured, attach it to the LLM
        if SYSTEM_PROMPT_PATH is not None:
            system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
            if system_prompt and hasattr(ollama_llm, 'system_prompt'):
                ollama_llm.system_prompt = system_prompt
                if VERBOSE:
                    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "System prompt attached to the LLM")

        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "Ollama LLM initialized (multimodal capable)")

        # Use the Ollama LLM directly rather than SimpleChatEngine
        await event_loop(ollama_llm)
    except Exception as e:
        config.Log("Error", f"Error: {e}")
        if VERBOSE:
            import traceback
            traceback.print_exc()
        return
    finally:
        cleanup_audio()


if __name__ == "__main__":
    asyncio.run(main())