cli已经可以接受图片输入, 但是有提示词的情况下存在问题
This commit is contained in:
351
cli.py
351
cli.py
@@ -5,6 +5,23 @@ from datetime import datetime
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.core.chat_engine import SimpleChatEngine
|
||||
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
|
||||
from llama_index.core.llms import ChatMessage
|
||||
try:
    from llama_index.core.llms.types import ImageBlock, TextBlock
except ImportError:
    try:
        # Try the alternative import path used by other llama-index versions.
        from llama_index.core import ImageBlock, TextBlock
    except ImportError:
        # Both imports failed: define minimal placeholder classes so the rest
        # of the module can still reference these names without crashing.
        class ImageBlock:
            # Placeholder carrying either inline base64 data or a file path.
            def __init__(self, base64_str=None, path=None):
                self.base64_str = base64_str
                self.path = path

        class TextBlock:
            # Placeholder carrying a plain-text payload.
            def __init__(self, text=""):
                self.text = text
|
||||
from llama_index.core import Settings
|
||||
from Convention.Runtime.File import ToolFile
|
||||
import requests
|
||||
@@ -13,6 +30,12 @@ import wave
|
||||
import io
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
from PIL import ImageGrab
|
||||
from io import BytesIO
|
||||
|
||||
chat_start_id = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
@@ -276,28 +299,253 @@ async def ainput(wait_seconds:float) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
async def achat(engine:SimpleChatEngine,message:str) -> None:
|
||||
user_message = message if message not in [None, ""] else "(没有人说话, 请延续发言或是寻找新的话题)"
|
||||
streaming_response: StreamingAgentChatResponse = await engine.astream_chat(user_message)
|
||||
def extract_image_paths(text: str) -> list[str]:
    """
    Extract every image path referenced with a leading ``@`` from *text*.

    Supports bare paths (no whitespace) as well as single- or double-quoted
    paths, which may contain spaces, e.g. ``@"my photo.png"``.

    Args:
        text: Raw user input.

    Returns:
        The paths that exist on disk and have a supported extension
        (.png, .jpg, .jpeg), in order of appearance.
    """
    # Three alternatives: "double-quoted", 'single-quoted', or a bare
    # whitespace-free path. Quoted forms may contain spaces.
    pattern = (
        r'@(?:"([^"]+\.(?:png|jpg|jpeg))"'
        r"|'([^']+\.(?:png|jpg|jpeg))'"
        r'|([^"\'\s]+\.(?:png|jpg|jpeg)))'
    )
    valid_paths = []
    for groups in re.findall(pattern, text, re.IGNORECASE):
        # findall yields one tuple per match; exactly one group is non-empty.
        match = next(g for g in groups if g)
        if os.path.isfile(match) and match.lower().endswith(('.png', '.jpg', '.jpeg')):
            valid_paths.append(match)
        elif VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"图片路径不存在或格式不支持: {match}")
    return valid_paths
|
||||
|
||||
def remove_image_paths_from_text(text: str, image_paths: list[str]) -> str:
    """
    Strip ``@``-style image references from *text*.

    Args:
        text: Original user input.
        image_paths: Paths previously extracted from the same text.

    Returns:
        The text with every ``@path`` reference removed and runs of
        whitespace collapsed to single spaces.
    """
    cleaned = text
    for image_path in image_paths:
        # The reference may be bare (@path) or wrapped in quotes (@"path").
        reference = rf'@["\']?{re.escape(image_path)}["\']?'
        cleaned = re.sub(reference, '', cleaned, flags=re.IGNORECASE)
    # Collapse the whitespace holes left behind by the removals.
    return re.sub(r'\s+', ' ', cleaned).strip()
|
||||
|
||||
def image_file_to_base64(image_path: str) -> Optional[str]:
    """
    Read an image file and return its contents as a base64 string.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The base64-encoded file contents, or ``None`` when the file cannot
        be read (the failure is logged when VERBOSE is enabled).
    """
    try:
        with open(image_path, 'rb') as image_file:
            image_data = image_file.read()
        return base64.b64encode(image_data).decode('utf-8')
    except OSError as e:
        # Best-effort: callers treat None as "no image available".
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"读取图片文件失败 {image_path}: {e}")
        return None
|
||||
|
||||
async def _ollama_stream_chat_with_image(llm: Ollama, messages: list, image_base64: str, end_symbol: list) -> None:
    """
    Stream a chat completion with an attached image directly from the
    Ollama HTTP API (bypassing the llama-index wrapper, which does not
    forward images).

    Prints each streamed text chunk to stdout as it arrives and, when TTS
    is enabled, flushes buffered text to `play_vocal` at sentence
    boundaries.

    Args:
        llm: Ollama LLM instance (its base_url/model/temperature are used
            when present; module-level defaults otherwise).
        messages: Chat messages (ChatMessage-like objects with .role/.content).
        image_base64: Base64-encoded image attached to the first user message.
        end_symbol: Characters treated as sentence terminators for TTS flushing.

    Raises:
        Exception: re-raised after logging when the API call or stream fails.
    """
    buffer_response = ""
    import aiohttp

    # Resolve endpoint and model, falling back to module-level config when
    # the llm object does not expose them.
    url = f"{llm.base_url}/api/chat" if hasattr(llm, 'base_url') else f"{OLLAMA_URL}/api/chat"
    model = llm.model if hasattr(llm, 'model') else OLLAMA_MODEL

    # Convert messages into the Ollama API wire format.
    api_messages = []
    for msg in messages:
        api_msg = {
            "role": msg.role if hasattr(msg, 'role') else "user",
            "content": msg.content if hasattr(msg, 'content') else str(msg)
        }
        # Attach the image only to the first user message in the list.
        if (hasattr(msg, 'role') and msg.role == "user") and image_base64 and len(api_messages) == 0:
            api_msg["images"] = [image_base64]
        api_messages.append(api_msg)

    payload = {
        "model": model,
        "messages": api_messages,
        "stream": True
    }

    if hasattr(llm, 'temperature') and llm.temperature:
        payload["options"] = {"temperature": llm.temperature}

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"Ollama API 错误: {response.status} - {error_text}")

                # Ollama streams newline-delimited JSON objects.
                async for line_bytes in response.content:
                    if not line_bytes:
                        continue

                    try:
                        import json
                        line_text = line_bytes.decode('utf-8').strip()
                        if not line_text:
                            continue

                        data = json.loads(line_text)
                        if 'message' in data and 'content' in data['message']:
                            chunk = data['message']['content']
                            if chunk:
                                # Small sleep keeps output pacing readable.
                                await asyncio.sleep(0.01)
                                print(chunk, end='', flush=True)
                                for ch in chunk:
                                    buffer_response += ch
                                    # Flush to TTS at a sentence boundary once
                                    # enough text has accumulated.
                                    if len(buffer_response) > 20:
                                        if ch in end_symbol:
                                            if TTS_ENABLE:
                                                await play_vocal(buffer_response.strip())
                                            buffer_response = ""

                        # Stop when the server marks the stream as done.
                        if data.get('done', False):
                            break
                    except json.JSONDecodeError:
                        # Skip malformed lines rather than aborting the stream.
                        continue
                    except Exception as e:
                        if VERBOSE:
                            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"解析响应错误: {e}")

        # Flush any trailing text that never hit a sentence terminator.
        buffer_response = buffer_response.strip()
        if len(buffer_response) > 0:
            if TTS_ENABLE:
                await play_vocal(buffer_response)
    except Exception as e:
        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"Ollama API 调用错误: {e}")
        if VERBOSE:
            import traceback
            traceback.print_exc()
        raise
|
||||
|
||||
def capture_screenshot() -> Optional[str]:
    """
    Grab the current screen and return it as a base64-encoded PNG.

    Returns:
        The base64 string, or ``None`` when the capture fails (e.g. a
        headless environment without a display; logged when VERBOSE is on).
    """
    try:
        # PIL's ImageGrab captures the full screen.
        screenshot = ImageGrab.grab()
        buffered = BytesIO()
        screenshot.save(buffered, format="PNG")
        image_data = buffered.getvalue()
        return base64.b64encode(image_data).decode('utf-8')
    except Exception as e:
        # Screenshots are best-effort; callers fall back to text-only chat.
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"截图失败: {e}")
        return None
|
||||
|
||||
async def achat(llm: Ollama, message: str, image_base64: Optional[str] = None, auto_screenshot: bool = False) -> None:
    """
    Run one multimodal chat turn against the Ollama LLM.

    Args:
        llm: Ollama LLM instance.
        message: The user's text; empty/None triggers autonomous AI speech.
        image_base64: Optional base64-encoded image (a single image).
        auto_screenshot: When no image is supplied, capture the screen
            automatically (used for AI-initiated turns).
    """
    user_message = message if message not in [None, ""] else "(没有人说话, 请延续发言或是寻找新的话题)"
    is_auto_speak = user_message == "(没有人说话, 请延续发言或是寻找新的话题)"

    # Resolve the image: prefer the caller-supplied one; otherwise take a
    # screenshot when requested or when the AI speaks on its own.
    if not image_base64 and (auto_screenshot or is_auto_speak):
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "AI 自主发言,自动截图...")
        image_base64 = capture_screenshot()
        if image_base64 and VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "截图成功")

    # The Ollama API expects raw base64 with no data-URL prefix.
    if image_base64:
        if image_base64.startswith('data:image'):
            image_base64 = image_base64.split(',', 1)[1]
        if VERBOSE:
            PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "已添加图片到消息")

    # Build the message list; any image travels via the raw Ollama API call.
    chat_message = ChatMessage(role="user", content=user_message)
    messages = [chat_message]

    # Prepend the system prompt into the user message, since some Ollama
    # models ignore a separate system role.
    if SYSTEM_PROMPT_PATH is not None:
        system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
        if system_prompt:
            user_message = f"{system_prompt}\n\n{user_message}"
            messages[0] = ChatMessage(role="user", content=user_message)

    buffer_response = ""
    end_symbol = ['。', '?', '!']

    try:
        if image_base64:
            # llama-index's wrapper does not forward images; stream through
            # the raw Ollama HTTP API instead.
            await _ollama_stream_chat_with_image(llm, messages, image_base64, end_symbol)
            return

        # Text-only path: stream through llama-index.
        # NOTE(review): in recent llama-index versions, llm.astream_chat
        # returns an async generator of ChatResponse rather than an object
        # with async_response_gen() — confirm against the pinned version.
        streaming_response = await llm.astream_chat(messages)

        # Print chunks live; flush buffered text to TTS at sentence
        # boundaries once enough text has accumulated.
        async for chunk in streaming_response.async_response_gen():
            await asyncio.sleep(0.01)
            print(chunk, end='', flush=True)
            for ch in chunk:
                buffer_response += ch
                if len(buffer_response) > 20 and ch in end_symbol:
                    if TTS_ENABLE:
                        await play_vocal(buffer_response.strip())
                    buffer_response = ""

        # Flush any trailing text that never hit a sentence terminator.
        buffer_response = buffer_response.strip()
        if buffer_response and TTS_ENABLE:
            await play_vocal(buffer_response)
    except Exception as e:
        PrintColorful(ConsoleFrontColor.LIGHTRED_EX, f"聊天错误: {e}")
        if VERBOSE:
            import traceback
            traceback.print_exc()
|
||||
|
||||
|
||||
def add_speaker() -> None:
|
||||
@@ -322,22 +570,64 @@ def add_speaker() -> None:
|
||||
raise Exception(f"添加音色失败: {response.status_code} - {response.text}")
|
||||
|
||||
|
||||
async def event_loop(engine:SimpleChatEngine) -> None:
|
||||
async def event_loop(llm: Ollama) -> None:
|
||||
"""
|
||||
事件循环,处理用户输入和AI响应
|
||||
|
||||
Args:
|
||||
llm: Ollama LLM 实例
|
||||
"""
|
||||
if TTS_ENABLE:
|
||||
add_speaker()
|
||||
audio_player_task = asyncio.create_task(audio_player_worker())
|
||||
else:
|
||||
audio_player_task = None
|
||||
|
||||
PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "多模态聊天已启动(支持文本和图片输入)")
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "提示:使用 @/path/to/image.png 来指定图片,否则将自动截图")
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "支持的图片格式: .png, .jpg, .jpeg")
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "输入 'quit' 或 'exit' 退出\n")
|
||||
|
||||
message = input("请开始对话: ")
|
||||
wait_second = AUTO_SPEAK_WAIT_SECOND
|
||||
try:
|
||||
while message != "quit" and message != "exit":
|
||||
image_base64 = None
|
||||
|
||||
# 检查用户输入中是否包含 @ 开头的图片路径
|
||||
image_paths = extract_image_paths(message)
|
||||
|
||||
if image_paths:
|
||||
# 如果找到图片路径,读取第一个图片文件
|
||||
image_path = image_paths[0]
|
||||
if VERBOSE:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"检测到图片路径: {image_path}")
|
||||
image_base64 = image_file_to_base64(image_path)
|
||||
# 从消息中移除图片路径引用
|
||||
message = remove_image_paths_from_text(message, image_paths)
|
||||
else:
|
||||
# 如果没有指定图片路径,自动截图
|
||||
if VERBOSE:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "未检测到图片路径,自动截图...")
|
||||
image_base64 = capture_screenshot()
|
||||
if image_base64 and VERBOSE:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "截图成功")
|
||||
|
||||
PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='')
|
||||
await achat(engine, message)
|
||||
await achat(llm, message, image_base64)
|
||||
PrintColorful(ConsoleFrontColor.RESET,"")
|
||||
|
||||
# 等待用户输入
|
||||
message = await ainput(wait_second)
|
||||
if not message:
|
||||
# 用户没有输入,触发 AI 自主发言(会自动截图)
|
||||
wait_second = max(wait_second*1.5, 3600)
|
||||
if VERBOSE:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, f"用户无输入,等待 {wait_second} 秒后 AI 自主发言...")
|
||||
# 触发 AI 自主发言(会自动截图)
|
||||
PrintColorful(ConsoleFrontColor.GREEN, "AI: ", is_reset=False, end='')
|
||||
await achat(llm, "", None, auto_screenshot=True)
|
||||
PrintColorful(ConsoleFrontColor.RESET,"")
|
||||
else:
|
||||
wait_second = AUTO_SPEAK_WAIT_SECOND
|
||||
finally:
|
||||
@@ -353,10 +643,25 @@ async def main():
|
||||
try:
|
||||
ollama_llm = Ollama(**ollama_llm_config)
|
||||
Settings.llm = ollama_llm
|
||||
chat_engine = SimpleChatEngine.from_defaults(**chat_engine_config)
|
||||
await event_loop(chat_engine)
|
||||
|
||||
# 如果有系统提示,设置到 LLM
|
||||
if SYSTEM_PROMPT_PATH is not None:
|
||||
system_prompt = ToolFile(SYSTEM_PROMPT_PATH).LoadAsText()
|
||||
if system_prompt and hasattr(ollama_llm, 'system_prompt'):
|
||||
ollama_llm.system_prompt = system_prompt
|
||||
if VERBOSE:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTYELLOW_EX, "系统提示已设置到 LLM")
|
||||
|
||||
if VERBOSE:
|
||||
PrintColorful(ConsoleFrontColor.LIGHTGREEN_EX, "Ollama LLM 已初始化(支持多模态)")
|
||||
|
||||
# 直接使用 Ollama LLM 而不是 SimpleChatEngine
|
||||
await event_loop(ollama_llm)
|
||||
except Exception as e:
|
||||
config.Log("Error", f"Error: {e}")
|
||||
if VERBOSE:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
finally:
|
||||
cleanup_audio()
|
||||
|
||||
Reference in New Issue
Block a user