# TheVirtualOne/cli.py
#
# NOTE: This copy was captured from a web source viewer (listed there as
# 694 lines, 26 KiB, Python). The viewer warned that the file contains
# ambiguous Unicode characters, i.e. characters that can be confused with
# others, such as full-width punctuation. Several such characters appear to
# be missing from this copy (e.g. empty or run-together string literals),
# so compare suspicious strings and comments against the upstream file.

from Convention.Runtime.GlobalConfig import *
import asyncio
from datetime import datetime
from llama_index.llms.ollama import Ollama
from llama_index.core.chat_engine import SimpleChatEngine
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
from llama_index.core.llms import ChatMessage
# try:
# from llama_index.core.llms.types import ImageBlock, TextBlock
# except ImportError:
# try:
# # 尝试其他可能的导入路径
# from llama_index.core import ImageBlock, TextBlock
# except ImportError:
# # 如果都失败,定义简单的占位类
# class ImageBlock:
# def __init__(self, base64_str=None, path=None):
# self.base64_str = base64_str
# self.path = path
#
# class TextBlock:
# def __init__(self, text=""):
# self.text = text
from llama_index.core import Settings
from Convention.Runtime.File import ToolFile
import requests
import subprocess
import wave
import io
import numpy as np
import pyaudio
import base64
import os
import re
from typing import Optional
from PIL import ImageGrab
from io import BytesIO
# --- Session identity and runtime configuration ------------------------------
# Timestamp identifying this chat session; it also names the per-run temp folder.
chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")
config = ProjectConfig()
# Settings are read via config.FindItem(key, fallback). The second argument is
# the value used when the key is absent (presumably it is also recorded so that
# SaveProperties() at the end of this section persists it; confirm against
# ProjectConfig).
OLLAMA_URL = config.FindItem("ollama_url", "http://localhost:11434")
OLLAMA_MODEL = config.FindItem("ollama_model", "gemma3:4b")
RESPONSE_TIMEOUT = config.FindItem("response_timeout", 60)
TEMPERATURE = config.FindItem("temperature", 1.3)
MAX_CONTENT_LENGTH = config.FindItem("max_content_length", None)
SYSTEM_PROMPT_PATH = config.FindItem("system_prompt_path", None)
AUTO_SPEAK_WAIT_SECOND = config.FindItem("auto_speak_wait_second", 15.0)
TTS_ENABLE = config.FindItem("tts_enable", False)
TTS_SERVER_URL = config.FindItem("tts_server_url", "http://localhost:43400")
TTS_PROMPT_TEXT = config.FindItem("tts_prompt_text", None)
TTS_PROMPT_WAV_PATH = config.FindItem("tts_prompt_wav_path", None)
TTS_SPEAKER_ID = config.FindItem("tts_speaker_id", "tts_speaker")
# NOTE(review): SEED is not referenced anywhere else in the visible source.
SEED = config.FindItem("seed", 0)
STREAM_ENABLE = config.FindItem("stream_enable", False)
VERBOSE = config.FindItem("verbose", False)
if VERBOSE:
    # Echo the effective settings once at startup.
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_URL: {OLLAMA_URL}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"OLLAMA_MODEL: {OLLAMA_MODEL}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"RESPONSE_TIMEOUT: {RESPONSE_TIMEOUT}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"TEMPERATURE: {TEMPERATURE}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"SYSTEM_PROMPT_PATH: {SYSTEM_PROMPT_PATH}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"AUTO_SPEAK_WAIT_SECOND: {AUTO_SPEAK_WAIT_SECOND}")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"TTS_ENABLE: {TTS_ENABLE}")
# Per-run directory where synthesized speech is saved (used by save_vocal_data).
# The trailing `|None` is ToolFile path syntax; it presumably marks the result
# as a directory. TODO confirm against ToolFile.
temp_dir = config.GetFile("temp")|chat_start_id|None
# Keyword arguments for llama_index's Ollama LLM; main() passes them to Ollama(**...).
ollama_llm_config = {
    "model": OLLAMA_MODEL,
    "base_url": OLLAMA_URL,
    "request_timeout": RESPONSE_TIMEOUT,
    "temperature": TEMPERATURE,
}
# NOTE(review): chat_engine_config is filled below but not read anywhere else in
# the visible source (SimpleChatEngine is imported but never constructed).
chat_engine_config = {}
if MAX_CONTENT_LENGTH is not None:
    # NOTE(review): confirm that llama_index's Ollama class accepts a
    # `max_content_length` field. Its context-size option may be named
    # differently (e.g. `context_window`).
    ollama_llm_config["max_content_length"] = MAX_CONTENT_LENGTH
    if VERBOSE:
        # NOTE(review): duplicates the MAX_CONTENT_LENGTH line already printed above.
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"MAX_CONTENT_LENGTH: {MAX_CONTENT_LENGTH}")
if SYSTEM_PROMPT_PATH is not None:
    # The prompt file is also re-read by achat() and main().
    system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
    chat_engine_config["system_prompt"] = system_prompt
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX,f"system_prompt: {system_prompt}")
# Persist the configuration.
config.SaveProperties()
def save_vocal_data(data: bytes) -> ToolFile:
    """
    Write one synthesized WAV payload into this run's temp directory.

    The file name is the current local time with microsecond resolution,
    formatted as ``YYYYmmdd_HHMMSS_ffffff.wav``.

    Args:
        data: Raw WAV bytes to store.

    Returns:
        The ToolFile that was written.
    """
    target = temp_dir | f"{datetime.now():%Y%m%d_%H%M%S_%f}.wav"
    # Presumably creates any missing parent directories (see ToolFile).
    target.MustExistsPath()
    target.SaveAsBinary(data)
    return target
# FIFO of complete WAV payloads waiting to be played by audio_player_worker().
# event_loop() pushes None into it as the shutdown sentinel, hence `| None`.
audio_play_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
# Lazily created PyAudio objects shared by play_audio_chunk() and cleanup_audio().
_pyaudio_instance: pyaudio.PyAudio | None = None
_pyaudio_stream: pyaudio.Stream | None = None
# Sample rate the current output stream was opened with (None when there is no stream).
_current_sample_rate: int | None = None
# TOCHECK
def parse_wav_chunk(wav_bytes: bytes) -> tuple[np.ndarray, int]:
    """
    Decode a complete in-memory PCM WAV file into mono 16-bit samples.

    Args:
        wav_bytes: Bytes of a whole WAV file, RIFF header included.

    Returns:
        ``(samples, sample_rate)``. ``samples`` is a 1-D ``np.int16`` array;
        multi-channel input is down-mixed by averaging the channels.
        ``sample_rate`` is the frame rate from the WAV header.

    Raises:
        ValueError: if the sample width is not 1, 2, 3 or 4 bytes.
        wave.Error / EOFError: if ``wav_bytes`` is not a readable WAV file.
    """
    with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
        sample_rate = wav_file.getframerate()
        n_channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        audio_bytes = wav_file.readframes(wav_file.getnframes())

    # A truncated stream can end mid-frame. np.frombuffer and the channel
    # reshape below both need whole frames, so drop the partial tail.
    frame_size = sample_width * n_channels
    if frame_size > 0:
        audio_bytes = audio_bytes[:len(audio_bytes) - len(audio_bytes) % frame_size]

    if sample_width == 1:
        # 8-bit WAV samples are unsigned with a midpoint of 128.
        audio_data = (np.frombuffer(audio_bytes, dtype=np.uint8).astype(np.int16) - 128) << 8
    elif sample_width == 2:
        audio_data = np.frombuffer(audio_bytes, dtype="<i2").astype(np.int16)
    elif sample_width == 3:
        # 24-bit little-endian: keep the two most significant bytes of each sample.
        triplets = np.frombuffer(audio_bytes, dtype=np.uint8).reshape(-1, 3)
        audio_data = np.ascontiguousarray(triplets[:, 1:]).view("<i2").reshape(-1).astype(np.int16)
    elif sample_width == 4:
        # Keep the 16 most significant bits. This is a fixed scale, so loudness
        # is preserved. Peak-normalising per call would give every streamed
        # chunk a different gain.
        audio_data = (np.frombuffer(audio_bytes, dtype="<i4") >> 16).astype(np.int16)
    else:
        raise ValueError(f"Unsupported sample width: {sample_width}")

    if n_channels > 1:
        # Frames are interleaved; average each frame's channels into one sample.
        audio_data = audio_data.reshape(-1, n_channels).mean(axis=1).astype(np.int16)
    return audio_data, sample_rate
# TOCHECK
def play_audio_chunk(audio_data: np.ndarray, sample_rate: int) -> None:
    """
    Play a mono sample array through the shared PyAudio output stream.

    The 16-bit mono output stream is created lazily. It is recreated, together
    with its PyAudio instance, whenever ``sample_rate`` differs from the rate
    the current stream was opened with. The stream has no callback, so
    ``stream.write`` is a blocking call; audio_player_worker therefore runs
    this function in a worker thread.

    Args:
        audio_data: 1-D sample array. Non-int16 input is peak-normalised into
            the int16 range first (an all-zero array becomes silence). An
            empty array is ignored.
        sample_rate: Sample rate of ``audio_data`` in Hz.
    """
    global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
    if audio_data.size == 0:
        return
    if _pyaudio_stream is None or _current_sample_rate != sample_rate:
        # Release the old stream/instance and reset the globals *before*
        # opening the new one. If open() fails we are left with no stream,
        # which is retried on the next call. Otherwise we would keep a closed
        # stream that later calls at the old sample rate would write to.
        cleanup_audio()
        frames_per_buffer = max(int(sample_rate * 0.02), 256)  # ~20 ms, at least 256 frames
        instance = pyaudio.PyAudio()
        try:
            stream = instance.open(
                format=pyaudio.paInt16,
                channels=1,
                rate=sample_rate,
                output=True,
                frames_per_buffer=frames_per_buffer,
            )
        except Exception:
            instance.terminate()
            raise
        _pyaudio_instance = instance
        _pyaudio_stream = stream
        _current_sample_rate = sample_rate
    if audio_data.dtype != np.int16:
        max_val = np.abs(audio_data).max()
        if max_val > 0:
            audio_data = (audio_data / max_val * 32767).astype(np.int16)
        else:
            audio_data = np.zeros_like(audio_data, dtype=np.int16)
    audio_bytes = audio_data.tobytes()
    chunk_size = 4096  # bytes per write; an even number, so whole int16 samples
    for offset in range(0, len(audio_bytes), chunk_size):
        _pyaudio_stream.write(audio_bytes[offset:offset + chunk_size])
# TOCHECK
def cleanup_audio() -> None:
    """
    Release the shared PyAudio output stream and PyAudio instance.

    Errors raised while stopping, closing or terminating are ignored, so this
    is always safe to call during shutdown. Afterwards the stream, the
    instance and the remembered sample rate are all reset to ``None``.
    """
    global _pyaudio_instance, _pyaudio_stream, _current_sample_rate
    stream = _pyaudio_stream
    if stream is not None:
        try:
            stream.stop_stream()
            stream.close()
        except Exception:
            pass
        _pyaudio_stream = None
    instance = _pyaudio_instance
    if instance is not None:
        try:
            instance.terminate()
        except Exception:
            pass
        _pyaudio_instance = None
    _current_sample_rate = None
# TOCHECK
def play_audio_sync(audio_data: bytes) -> None:
    """
    Decode one complete WAV payload and play it, blocking until it is written.

    Empty (falsy) payloads are ignored.
    """
    if audio_data:
        play_audio_chunk(*parse_wav_chunk(audio_data))
# TOCHECK
async def audio_player_worker():
    """
    Consume queued WAV payloads and play them one at a time, in order.

    Each payload is played with play_audio_sync() and then saved with
    save_vocal_data(). Both run in worker threads so the event loop stays
    responsive. Failures are reported and the worker keeps going. A ``None``
    item is the shutdown sentinel. ``task_done()`` is called for every item,
    including the sentinel, so ``audio_play_queue.join()`` can complete.
    """
    while (payload := await audio_play_queue.get()) is not None:
        try:
            await asyncio.to_thread(play_audio_sync, payload)
            await asyncio.to_thread(save_vocal_data, payload)
        except Exception as exc:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"音频播放失败: {exc}")
        finally:
            audio_play_queue.task_done()
    # Acknowledge the sentinel itself before exiting.
    audio_play_queue.task_done()
# CHANGE TOCHECK
def _extract_complete_wavs(wav_buffer: bytearray) -> list[bytes]:
    """
    Remove and return every complete RIFF/WAV file at the front of ``wav_buffer``.

    The buffer is modified in place:
    - bytes before the next ``b'RIFF'`` marker are discarded;
    - complete files are cut off the front and returned in order;
    - an incomplete trailing file stays in the buffer until more data arrives;
    - if no marker is found at all, the buffer is cleared.

    NOTE(review): a streamed WAV whose RIFF size field is a placeholder (0 or
    0xFFFFFFFF) will not be split correctly. Confirm the TTS server writes
    real sizes.
    """
    complete: list[bytes] = []
    while len(wav_buffer) > 12:
        if wav_buffer[:4] != b'RIFF':
            riff_pos = wav_buffer.find(b'RIFF', 1)
            if riff_pos == -1:
                wav_buffer.clear()
                break
            del wav_buffer[:riff_pos]
            if len(wav_buffer) < 8:
                break
        # The little-endian RIFF size counts every byte after the 8-byte header.
        expected_size = int.from_bytes(wav_buffer[4:8], byteorder='little') + 8
        if len(wav_buffer) < expected_size:
            break
        complete.append(bytes(wav_buffer[:expected_size]))
        del wav_buffer[:expected_size]
    return complete


async def play_vocal(text: str) -> None:
    """
    Synthesize ``text`` on the TTS server and queue the WAV audio for playback.

    Does nothing when TTS is disabled or ``text`` is empty/None.

    - Non-streaming: the whole response body is queued as one WAV.
    - Streaming (STREAM_ENABLE): the body is treated as concatenated WAV
      files. Each complete file is queued as soon as it has fully arrived.
      A trailing partial file is queued only if it still parses as a WAV.

    The HTTP request and the streamed reads run in worker threads, so the
    event loop (and with it audio_player_worker) keeps running meanwhile.

    Raises:
        Exception: if the TTS server answers with a non-200 status.
    """
    if not TTS_ENABLE:
        return
    if not text:
        return
    tts_server_url = f"{TTS_SERVER_URL}/api/synthesis/sft"
    header = {
        "accept": "application/json"
    }
    data = {
        'text': text,
        'speaker_id': TTS_SPEAKER_ID,
        'stream': STREAM_ENABLE,
    }
    # requests is synchronous; running it in a thread keeps the event loop free.
    response = await asyncio.to_thread(
        requests.post, tts_server_url, data=data, stream=STREAM_ENABLE, timeout=600, headers=header
    )
    try:
        if response.status_code != 200:
            raise Exception(f"语音合成失败: {response.status_code} - {response.text}")
        if not STREAM_ENABLE:
            await audio_play_queue.put(response.content)
            return
        wav_buffer = bytearray()
        chunks = response.iter_content(chunk_size=1024 * 256)
        # Each network read blocks, so it is pulled in a worker thread as well.
        while (chunk := await asyncio.to_thread(next, chunks, None)) is not None:
            if not chunk:
                continue
            wav_buffer.extend(chunk)
            for complete_wav in _extract_complete_wavs(wav_buffer):
                await audio_play_queue.put(complete_wav)
        if wav_buffer:
            leftover = bytes(wav_buffer)
            try:
                # Only queue the tail if it is itself a readable WAV.
                parse_wav_chunk(leftover)
                await audio_play_queue.put(leftover)
            except Exception:
                PrintColorful(ConsoleFrontColor.LIGHTRED_EX, "剩余音频数据解析失败,已丢弃")
    finally:
        # Release the connection; with stream=True it stays open until closed.
        response.close()
# Read of the line currently being typed. It survives ainput() timeouts, so the
# next call reuses it instead of starting a second, competing input() thread.
_pending_input: asyncio.Future[str] | None = None


async def ainput(wait_seconds: float) -> str:
    """
    Wait up to ``wait_seconds`` for a line typed on stdin.

    ``input()`` runs in the default executor so the event loop is not blocked.
    A thread blocked in ``input()`` cannot be cancelled, so on timeout the
    pending read is kept and reused by the next call. The next line the user
    types is therefore returned rather than swallowed by an orphaned reader.

    Args:
        wait_seconds: Maximum time to wait; values <= 0 return almost immediately.

    Returns:
        The typed line, or "" on timeout or at end of input (EOF).
    """
    global _pending_input
    if _pending_input is None:
        def get_input() -> str:
            try:
                return input("\n你: ")
            except EOFError:
                return ""

        _pending_input = asyncio.get_running_loop().run_in_executor(None, get_input)
    else:
        # Re-show the prompt; the original one has likely scrolled above AI output.
        print("\n你: ", end="", flush=True)
    pending = _pending_input
    # Unlike wait_for, asyncio.wait does not cancel the future on timeout.
    done, _ = await asyncio.wait({pending}, timeout=max(wait_seconds, 0.0))
    if not done:
        return ""
    _pending_input = None
    return pending.result()
def extract_image_paths(text: str) -> list[str]:
    """
    Extract the existing image files referenced in ``text`` with a leading ``@``.

    Accepted forms (extensions .png/.jpg/.jpeg, case-insensitive):
    - ``@path/img.png``: no spaces allowed;
    - ``@"path with spaces/img.png"`` and ``@'path with spaces/img.png'``.

    Args:
        text: The user's input text.

    Returns:
        Paths that point at an existing file, in order of appearance.
    """
    ext = r'\.(?:png|jpg|jpeg)'
    pattern = (
        rf'@(?:"([^"]+{ext})"'        # double-quoted; may contain spaces
        rf"|'([^']+{ext})'"           # single-quoted; may contain spaces
        rf'|["\']?([^"\'\s]+{ext}))'  # bare (or unbalanced-quote) path without spaces
    )
    valid_paths = []
    for found in re.finditer(pattern, text, re.IGNORECASE):
        # Exactly one alternative (and therefore one group) matched.
        path = next(group for group in found.groups() if group)
        if os.path.isfile(path) and path.lower().endswith(('.png', '.jpg', '.jpeg')):
            valid_paths.append(path)
        elif VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"图片路径不存在或格式不支持: {path}")
    return valid_paths
def remove_image_paths_from_text(text: str, image_paths: list[str]) -> str:
    """
    Strip ``@path`` image references out of a message.

    Every occurrence of ``@path`` for each entry in ``image_paths`` is deleted.
    Matching is case-insensitive and the path may be wrapped in single or
    double quotes. Runs of whitespace are then collapsed to single spaces and
    both ends are trimmed.

    Args:
        text: Original message text.
        image_paths: Paths whose ``@`` references should be removed.

    Returns:
        The cleaned message.
    """
    cleaned = text
    for image_path in image_paths:
        reference = '@["\']?' + re.escape(image_path) + '["\']?'
        cleaned = re.sub(reference, '', cleaned, flags=re.IGNORECASE)
    return ' '.join(cleaned.split())
def image_file_to_base64(image_path: str) -> Optional[str]:
    """
    Read an image file and return its contents base64-encoded.

    Args:
        image_path: Path of the image file.

    Returns:
        The base64 string, or None if the file could not be read (the error
        is printed only when VERBOSE is set).
    """
    try:
        with open(image_path, 'rb') as image_file:
            image_data = image_file.read()
        return base64.b64encode(image_data).decode('utf-8')
    except Exception as e:
        # Best effort: the caller treats None as "no image".
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"读取图片文件失败 {image_path}: {e}")
        return None
async def _ollama_stream_chat_with_image(llm: Ollama, messages: list, image_base64: str, end_symbol: list) -> str:
    """
    Stream a chat reply with an attached image directly from Ollama's ``/api/chat``.

    The llama-index wrapper is bypassed so the base64 image can be sent in the
    ``images`` field of the last user message. Streamed text is printed as it
    arrives. When TTS is enabled it is also handed to play_vocal() in
    segments: a segment is flushed once more than 20 characters are buffered
    and a character from ``end_symbol`` arrives.

    Args:
        llm: Ollama LLM instance. Its ``base_url``/``model``/``temperature`` are
            used when present; otherwise OLLAMA_URL / OLLAMA_MODEL.
        messages: Chat history (objects with ``role``/``content``, e.g. ChatMessage).
        image_base64: Base64-encoded image without a ``data:`` URL prefix.
        end_symbol: Characters that may end a TTS segment.

    Returns:
        The complete assistant reply, stripped of surrounding whitespace.

    Raises:
        Exception: on a non-200 HTTP status or a transport error (logged, then re-raised).
    """
    import json

    import aiohttp

    # Prefer the LLM's own endpoint/model; fall back to the module configuration.
    url = f"{llm.base_url}/api/chat" if hasattr(llm, 'base_url') else f"{OLLAMA_URL}/api/chat"
    model = llm.model if hasattr(llm, 'model') else OLLAMA_MODEL

    # Convert the history to Ollama's wire format: {"role": ..., "content": ...}.
    api_messages = [
        {
            "role": msg.role if hasattr(msg, 'role') else "user",
            "content": msg.content if hasattr(msg, 'content') else str(msg),
        }
        for msg in messages
    ]
    # Attach the image to the most recent user message (the current turn).
    if image_base64:
        for api_msg in reversed(api_messages):
            if api_msg.get("role") == "user":
                api_msg["images"] = [image_base64]
                break

    payload = {
        "model": model,
        "messages": api_messages,
        "stream": True,
    }
    temperature = getattr(llm, 'temperature', None)
    if temperature is not None:
        # Test `is not None`, not truthiness, so an explicit 0.0 is still sent.
        payload["options"] = {"temperature": temperature}

    reply_parts: list[str] = []  # everything streamed so far; returned to the caller
    buffer_response = ""         # text not yet handed to TTS
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"Ollama API 错误: {response.status} - {error_text}")
                # Ollama streams one JSON object per line.
                async for line_bytes in response.content:
                    if not line_bytes:
                        continue
                    try:
                        line_text = line_bytes.decode('utf-8').strip()
                        if not line_text:
                            continue
                        data = json.loads(line_text)
                        if 'message' in data and 'content' in data['message']:
                            chunk = data['message']['content']
                            if chunk:
                                await asyncio.sleep(0.01)
                                print(chunk, end='', flush=True)
                                reply_parts.append(chunk)
                                for ch in chunk:
                                    buffer_response += ch
                                    if len(buffer_response) > 20 and ch in end_symbol:
                                        if TTS_ENABLE:
                                            await play_vocal(buffer_response.strip())
                                        buffer_response = ""
                        if data.get('done', False):
                            break
                    except json.JSONDecodeError:
                        continue
                    except Exception as e:
                        if VERBOSE:
                            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"解析响应错误: {e}")
        # Speak whatever is left after the last segment boundary.
        tail = buffer_response.strip()
        if tail and TTS_ENABLE:
            await play_vocal(tail)
        # Return the whole reply, not just the unspoken tail, so the
        # conversation history records everything the assistant said.
        return "".join(reply_parts).strip()
    except Exception as e:
        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Ollama API 调用错误: {e}")
        if VERBOSE:
            import traceback
            traceback.print_exc()
        raise
def capture_screenshot() -> Optional[str]:
    """
    Grab the current screen and return it as a base64-encoded PNG.

    Returns:
        The base64 string, or None if capturing or encoding failed (the
        error is printed only when VERBOSE is set).
    """
    try:
        screenshot = ImageGrab.grab()
        buffered = BytesIO()
        screenshot.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
    except Exception as e:
        # Best effort: callers treat None as "no image".
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"截图失败: {e}")
        return None
async def achat(llm: Ollama, message: str, image_base64: Optional[str] = None, auto_screenshot: bool = False, conversation_history: Optional[list] = None) -> list:
    """
    Run one (optionally multimodal) chat turn with the Ollama LLM.

    The reply is streamed to stdout. When TTS is enabled it is also spoken
    segment by segment via play_vocal().

    Args:
        llm: Ollama LLM instance.
        message: User text. Empty/None means nobody spoke; a fixed placeholder
            prompt asking the model to continue or find a new topic is sent
            instead.
        image_base64: Optional base64 image. A ``data:image/...;base64,``
            prefix is stripped.
        auto_screenshot: Take a screenshot when no image is given. The
            placeholder turn also takes one automatically.
        conversation_history: Previous ChatMessages; not mutated. When it is
            empty, the configured system prompt (if any) is prepended.

    Returns:
        The new history: previous messages, the system prompt (first turn
        only), this user message, and the assistant reply when non-empty. If
        the model call fails, the error is printed and the incoming history is
        returned.
    """
    auto_speak_prompt = "(没有人说话, 请延续发言或是寻找新的话题)"
    user_message = message if message not in [None, ""] else auto_speak_prompt
    is_auto_speak = user_message == auto_speak_prompt

    # Use the supplied image; otherwise screenshot when asked, or when the AI speaks on its own.
    if not image_base64 and (auto_screenshot or is_auto_speak):
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "AI 自主发言,自动截图...")
        image_base64 = capture_screenshot()
        if image_base64 and VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "截图成功")

    if image_base64:
        if image_base64.startswith('data:image'):
            # Ollama wants bare base64. partition() yields "" instead of raising
            # IndexError for a malformed prefix that has no comma.
            image_base64 = image_base64.partition(',')[2]
        if image_base64 and VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "已添加图片到消息")

    if conversation_history is None:
        conversation_history = []
    messages = list(conversation_history)
    if not messages and SYSTEM_PROMPT_PATH is not None:
        system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
        if system_prompt:
            messages.append(ChatMessage(role="system", content=system_prompt))
    # Only text goes into the history; the image is sent with this turn only.
    messages.append(ChatMessage(role="user", content=user_message))

    # Characters that may end a TTS segment: full-width 。!? and ASCII !?.
    # Written as escapes so they cannot be confused or mangled. Every entry
    # must be a single non-empty character, because the check is `ch in end_symbol`.
    end_symbol = ['\u3002', '\uff01', '\uff1f', '!', '?']

    try:
        if image_base64:
            # The llama-index wrapper is bypassed so the image can be attached.
            assistant_response = await _ollama_stream_chat_with_image(llm, messages, image_base64, end_symbol)
        else:
            # Ollama.astream_chat() resolves to an async generator of
            # ChatResponse objects; `.delta` holds the newly streamed text.
            stream = await llm.astream_chat(messages)
            reply_parts: list[str] = []  # the whole reply, for the history
            buffer_response = ""         # text not yet handed to TTS
            async for partial in stream:
                chunk = partial.delta or ""
                await asyncio.sleep(0.01)
                print(chunk, end='', flush=True)
                reply_parts.append(chunk)
                for ch in chunk:
                    buffer_response += ch
                    if len(buffer_response) > 20 and ch in end_symbol:
                        if TTS_ENABLE:
                            await play_vocal(buffer_response.strip())
                        buffer_response = ""
            tail = buffer_response.strip()
            if tail and TTS_ENABLE:
                await play_vocal(tail)
            assistant_response = "".join(reply_parts).strip()

        if assistant_response:
            messages.append(ChatMessage(role="assistant", content=assistant_response))
        return messages
    except Exception as e:
        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"聊天错误: {e}")
        if VERBOSE:
            import traceback
            traceback.print_exc()
        return conversation_history
def add_speaker() -> None:
    """
    Register the configured voice with the TTS server (re-registering it if needed).

    Uploads TTS_PROMPT_WAV_PATH as the reference audio, together with
    TTS_PROMPT_TEXT, under TTS_SPEAKER_ID. ``force_regenerate`` is sent as
    True (presumably this makes the server rebuild the voice even if it
    exists; confirm against the TTS API).

    Raises:
        ValueError: if ``tts_prompt_wav_path`` is not configured.
        Exception: if the server answers with a non-200 status.
    """
    if not TTS_PROMPT_WAV_PATH:
        # Fail with a clear message instead of open(None)'s opaque TypeError.
        raise ValueError("已启用 TTS, 但未配置 tts_prompt_wav_path")
    url = f"{TTS_SERVER_URL}/api/speakers/add"
    headers = {
        "accept": "application/json"
    }
    data = {
        "speaker_id": TTS_SPEAKER_ID,
        "prompt_text": TTS_PROMPT_TEXT,
        "force_regenerate": True
    }
    # NOTE(review): assumes GetExtension() returns the suffix without a leading dot.
    extension = ToolFile(TTS_PROMPT_WAV_PATH).GetExtension().lower()
    with open(TTS_PROMPT_WAV_PATH, 'rb') as f:
        files = {
            'prompt_wav': (f'prompt.{extension}', f, f'audio/{extension}')
        }
        response = requests.post(url, data=data, files=files, headers=headers, timeout=600)
    if response.status_code == 200:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"音色可用: {response.text}")
    else:
        raise Exception(f"添加音色失败: {response.status_code} - {response.text}")
def _resolve_user_image(message: str) -> tuple[str, Optional[str]]:
    """
    Pick the image to send with a user message and strip its ``@path`` references.

    The first existing ``@path`` image in ``message`` is used; without one, a
    screenshot is taken instead.

    Returns:
        ``(message_without_image_refs, image_base64)``. ``image_base64`` is
        None when the image could not be read or captured.
    """
    image_paths = extract_image_paths(message)
    if image_paths:
        image_path = image_paths[0]
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"检测到图片路径: {image_path}")
        image_base64 = image_file_to_base64(image_path)
        return remove_image_paths_from_text(message, image_paths), image_base64
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "未检测到图片路径,自动截图...")
    image_base64 = capture_screenshot()
    if image_base64 and VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "截图成功")
    return message, image_base64


async def event_loop(llm: Ollama) -> None:
    """
    Interactive chat loop: read user input, stream AI replies, and auto-speak when idle.

    Each user message is sent with an image: the first existing ``@path``
    image in the text, or otherwise a fresh screenshot. If the user types
    nothing for ``wait_second`` seconds, the AI speaks once on its own (with a
    screenshot) and the wait grows by 1.5x, capped at 3600 s. Any user input
    resets the wait to AUTO_SPEAK_WAIT_SECOND. Typing ``quit`` or ``exit``
    ends the loop. With TTS enabled, the voice is registered first and queued
    audio is drained before returning.

    Args:
        llm: Ollama LLM instance.
    """
    if TTS_ENABLE:
        add_speaker()
        audio_player_task = asyncio.create_task(audio_player_worker())
    else:
        audio_player_task = None
    PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "多模态聊天已启动(支持文本和图片输入)")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "提示:使用 @/path/to/image.png 来指定图片,否则将自动截图")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "支持的图片格式: .png, .jpg, .jpeg")
    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "输入 'quit' 或 'exit' 退出\n")
    # History is carried between turns so the system prompt is sent only once.
    conversation_history = []
    try:
        message = input("请开始对话: ")
    except EOFError:
        # stdin is already closed: there is nothing to chat about.
        message = "quit"
    wait_second = AUTO_SPEAK_WAIT_SECOND
    try:
        while message.strip() not in ("quit", "exit"):
            if message:
                message, image_base64 = _resolve_user_image(message)
                PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='')
                conversation_history = await achat(llm, message, image_base64, False, conversation_history)
            else:
                # Nobody spoke: the AI continues on its own; achat takes the screenshot.
                PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='')
                conversation_history = await achat(llm, "", None, auto_screenshot=True, conversation_history=conversation_history)
            PrintColorful(ConsoleFrontColor.RESET, "")

            message = await ainput(wait_second)
            if message:
                wait_second = AUTO_SPEAK_WAIT_SECOND
            else:
                # Grow the idle interval (x1.5, capped at one hour) after each silent round.
                wait_second = min(wait_second * 1.5, 3600)
                if VERBOSE:
                    PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"用户无输入,等待 {wait_second} 秒后 AI 自主发言...")
    finally:
        if TTS_ENABLE and audio_player_task is not None:
            # Let queued speech finish, then stop the worker with the None sentinel.
            await audio_play_queue.join()
            await audio_play_queue.put(None)
            await audio_player_task
        cleanup_audio()
def _apply_configured_system_prompt(llm: Ollama) -> None:
    """Copy the configured system prompt, if any, onto ``llm.system_prompt``."""
    if SYSTEM_PROMPT_PATH is None:
        return
    prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
    if not prompt or not hasattr(llm, 'system_prompt'):
        return
    llm.system_prompt = prompt
    if VERBOSE:
        PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "系统提示已设置到 LLM")


async def main():
    """
    Entry coroutine: build the Ollama LLM from ``ollama_llm_config`` and run the chat loop.

    The LLM is used directly rather than through SimpleChatEngine. Any
    exception is logged via ``config.Log`` (with a traceback when VERBOSE)
    instead of propagating. Audio resources are released on every exit path.
    """
    try:
        llm = Ollama(**ollama_llm_config)
        Settings.llm = llm
        _apply_configured_system_prompt(llm)
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "Ollama LLM 已初始化(支持多模态)")
        await event_loop(llm)
    except Exception as e:
        config.Log("Error", f"Error: {e}")
        if VERBOSE:
            import traceback
            traceback.print_exc()
    finally:
        cleanup_audio()
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C: exit quietly with the conventional SIGINT status (130)
        # instead of printing an asyncio traceback. asyncio.run cancels the
        # main task, so the `finally` cleanup in main()/event_loop() runs first.
        raise SystemExit(130)