|
@@ -0,0 +1,125 @@
|
|
|
|
+import asyncio
|
|
|
|
+
|
|
|
|
+from zhipuai import ZhipuAI
|
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
|
+from anon import PluginManager, Plugin
|
|
|
|
+from anon.event import MessageEvent
|
|
|
|
+from anon.event.message import GroupMessage
|
|
|
|
+from anon.logger import logger
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class GlmClient:
|
|
|
|
+ max_cut_length = 2000
|
|
|
|
+ cache: dict
|
|
|
|
+ system: str
|
|
|
|
+ _client: ZhipuAI
|
|
|
|
+ _executor: ThreadPoolExecutor
|
|
|
|
+ _loop: asyncio.AbstractEventLoop
|
|
|
|
+
|
|
|
|
+ def __init__(self) -> None:
|
|
|
|
+ self.cache = {}
|
|
|
|
+ self._client = ZhipuAI(
|
|
|
|
+ api_key="017d4586c9ae229bc83bd483d954ff45.B03rIwOhtdgL6mSC")
|
|
|
|
+ self._executor = ThreadPoolExecutor(max_workers=5)
|
|
|
|
+ self._loop = asyncio.get_running_loop()
|
|
|
|
+ self.system = "你是 BangDream 企划中 MyGO 乐队的吉他手千早爱音,日本女高中生。" \
|
|
|
|
+ + "和用户进行对话,你可以回答用户的问题,也可以和用户进行闲聊。"
|
|
|
|
+
|
|
|
|
+ def resp_wrapper(msg):
|
|
|
|
+ return self._client.chat.completions.create(
|
|
|
|
+ model="glm-4",
|
|
|
|
+ messages=msg,
|
|
|
|
+ top_p=0.7,
|
|
|
|
+ temperature=0.9,
|
|
|
|
+ max_tokens=1024
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ self.resp_wrapper = resp_wrapper
|
|
|
|
+ logger.info('GlmClient initialized')
|
|
|
|
+
|
|
|
|
+ async def generate_response(self, id: int, sentence: str) -> str:
|
|
|
|
+ message = [
|
|
|
|
+ {
|
|
|
|
+ "role": "system",
|
|
|
|
+ "content": self.system
|
|
|
|
+ }
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ if id in self.cache:
|
|
|
|
+ message = self.cache[id]
|
|
|
|
+ else:
|
|
|
|
+ self.cache[id] = message
|
|
|
|
+
|
|
|
|
+ message.append({
|
|
|
|
+ "role": "user",
|
|
|
|
+ "content": sentence
|
|
|
|
+ })
|
|
|
|
+ res = ''
|
|
|
|
+ del_flag = False
|
|
|
|
+ try:
|
|
|
|
+ response = await self._loop.run_in_executor(self._executor, self.resp_wrapper, message)
|
|
|
|
+ res = response.choices[0].message.content
|
|
|
|
+ message.append({
|
|
|
|
+ "role": "assistant",
|
|
|
|
+ "content": res
|
|
|
|
+ })
|
|
|
|
+ del_flag = sum([len(i['content'])
|
|
|
|
+ for i in message]) > self.max_cut_length
|
|
|
|
+ except Exception as e:
|
|
|
|
+ res = f'Error: {e}'
|
|
|
|
+ del_flag = True
|
|
|
|
+
|
|
|
|
+ if del_flag:
|
|
|
|
+ logger.info(f'GlmClient: drop session, id: {id}')
|
|
|
|
+ del self.cache[id]
|
|
|
|
+ res += '\n\n-----Session Cut off-----'
|
|
|
|
+ return res
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class GlmPlugin(Plugin):
|
|
|
|
+ _client: GlmClient
|
|
|
|
+
|
|
|
|
+ async def on_load(self):
|
|
|
|
+ self._client = GlmClient()
|
|
|
|
+
|
|
|
|
+ async def on_event(self, event: MessageEvent):
|
|
|
|
+ rid = event.sender.user_id
|
|
|
|
+ if isinstance(event, GroupMessage):
|
|
|
|
+ rid = event.gid
|
|
|
|
+
|
|
|
|
+ if event.raw.startswith('glm'):
|
|
|
|
+ query = event.raw[3:].strip()
|
|
|
|
+ logger.info(f'GlmQuery: {query}')
|
|
|
|
+ await event.reply(await self._client.generate_response(rid, query))
|
|
|
|
+ if event.raw.startswith('glset'):
|
|
|
|
+ cmds = event.raw[5:].strip().split(' ', 1)
|
|
|
|
+ cmd = cmds[0]
|
|
|
|
+ if len(cmds) > 1:
|
|
|
|
+ args = cmds[1].split(' ')
|
|
|
|
+ else:
|
|
|
|
+ args = []
|
|
|
|
+
|
|
|
|
+ logger.info(f'GlmSet: {cmd} {args}')
|
|
|
|
+ if cmd == 'help':
|
|
|
|
+ await event.reply(
|
|
|
|
+ 'glset commands:\n'
|
|
|
|
+ ' - curlen [int]: set max cut length\n'
|
|
|
|
+ ' - s/sys [str]: set system prompt\n'
|
|
|
|
+ ' - g/sys: get system prompt\n'
|
|
|
|
+ ' - clear: clear cache'
|
|
|
|
+ )
|
|
|
|
+ if cmd == 'cutlen':
|
|
|
|
+ self._client.max_cut_length = int(args[0])
|
|
|
|
+ await event.reply(f'max_cut_length set to {args[0]}')
|
|
|
|
+ if cmd == 's/sys':
|
|
|
|
+ self._client.system = ' '.join(args)
|
|
|
|
+ self._client.cache.clear()
|
|
|
|
+ await event.reply(f'system prompt set to: {" ".join(args)}')
|
|
|
|
+ if cmd == 'g/sys':
|
|
|
|
+ await event.reply(f'system prompt: {self._client.system}')
|
|
|
|
+ if cmd == 'clear':
|
|
|
|
+ self._client.cache.clear()
|
|
|
|
+ await event.reply('cache cleared')
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+PluginManager().register_plugin(GlmPlugin([MessageEvent]))
|