jerrita, 9 months ago
Current commit: 5d7c41bfe0
4 files changed, 139 insertions(+), 0 deletions(-)
  1. main.py (+1, -0)
  2. plugins/anon/glm.py (+125, -0)
  3. requirements.txt (+1, -0)
  4. test.py (+12, -0)

main.py (+1, -0)

@@ -10,6 +10,7 @@ if __name__ == '__main__':
         'plugins.anon.base',
         'plugins.anon.muri',
         'plugins.anon.turing',
+        'plugins.anon.glm',
         'plugins.corpus.bang'
     ])
     anon.loop()
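
Registering 'plugins.anon.glm' here tells the bot to import the module at startup; the PluginManager().register_plugin(...) call at the bottom of glm.py (below) then presumably hooks the plugin in at import time.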

plugins/anon/glm.py (+125, -0)

@@ -0,0 +1,125 @@
+import asyncio
+
+from zhipuai import ZhipuAI
+from concurrent.futures import ThreadPoolExecutor
+from anon import PluginManager, Plugin
+from anon.event import MessageEvent
+from anon.event.message import GroupMessage
+from anon.logger import logger
+
+
+class GlmClient:
+    max_cut_length = 2000
+    cache: dict
+    system: str
+    _client: ZhipuAI
+    _executor: ThreadPoolExecutor
+    _loop: asyncio.AbstractEventLoop
+
+    def __init__(self) -> None:
+        self.cache = {}
+        self._client = ZhipuAI(
+            api_key="017d4586c9ae229bc83bd483d954ff45.B03rIwOhtdgL6mSC")
+        self._executor = ThreadPoolExecutor(max_workers=5)
+        self._loop = asyncio.get_running_loop()
+        # Persona prompt (in Chinese): "You are Chihaya Anon, guitarist of
+        # the band MyGO from the BangDream franchise, a Japanese high school
+        # girl. Chat with users: answer their questions or make small talk."
+        self.system = "你是 BangDream 企划中 MyGO 乐队的吉他手千早爱音,日本女高中生。" \
+            + "和用户进行对话,你可以回答用户的问题,也可以和用户进行闲聊。"
+
+        def resp_wrapper(msg):
+            return self._client.chat.completions.create(
+                model="glm-4",
+                messages=msg,
+                top_p=0.7,
+                temperature=0.9,
+                max_tokens=1024
+            )
+
+        self.resp_wrapper = resp_wrapper
+        logger.info('GlmClient initialized')
+
+    async def generate_response(self, id: int, sentence: str) -> str:
+        message = [
+            {
+                "role": "system",
+                "content": self.system
+            }
+        ]
+
+        if id in self.cache:
+            message = self.cache[id]
+        else:
+            self.cache[id] = message
+
+        message.append({
+            "role": "user",
+            "content": sentence
+        })
+        res = ''
+        del_flag = False
+        try:
+            response = await self._loop.run_in_executor(self._executor, self.resp_wrapper, message)
+            res = response.choices[0].message.content
+            message.append({
+                "role": "assistant",
+                "content": res
+            })
+            del_flag = sum([len(i['content'])
+                           for i in message]) > self.max_cut_length
+        except Exception as e:
+            res = f'Error: {e}'
+            del_flag = True
+
+        if del_flag:
+            logger.info(f'GlmClient: drop session, id: {id}')
+            del self.cache[id]
+            res += '\n\n-----Session Cut off-----'
+        return res
+
+
+class GlmPlugin(Plugin):
+    _client: GlmClient
+
+    async def on_load(self):
+        self._client = GlmClient()
+
+    async def on_event(self, event: MessageEvent):
+        rid = event.sender.user_id
+        if isinstance(event, GroupMessage):
+            rid = event.gid
+
+        if event.raw.startswith('glm'):
+            query = event.raw[3:].strip()
+            logger.info(f'GlmQuery: {query}')
+            await event.reply(await self._client.generate_response(rid, query))
+        if event.raw.startswith('glset'):
+            cmds = event.raw[5:].strip().split(' ', 1)
+            cmd = cmds[0]
+            if len(cmds) > 1:
+                args = cmds[1].split(' ')
+            else:
+                args = []
+
+            logger.info(f'GlmSet: {cmd} {args}')
+            if cmd == 'help':
+                await event.reply(
+                    'glset commands:\n'
+                    '   - cutlen [int]: set max cut length\n'
+                    '   - s/sys [str]: set system prompt\n'
+                    '   - g/sys: get system prompt\n'
+                    '   - clear: clear cache'
+                )
+            if cmd == 'cutlen':
+                if not args:
+                    await event.reply('usage: glset cutlen [int]')
+                    return
+                self._client.max_cut_length = int(args[0])
+                await event.reply(f'max_cut_length set to {args[0]}')
+            if cmd == 's/sys':
+                self._client.system = ' '.join(args)
+                self._client.cache.clear()
+                await event.reply(f'system prompt set to: {" ".join(args)}')
+            if cmd == 'g/sys':
+                await event.reply(f'system prompt: {self._client.system}')
+            if cmd == 'clear':
+                self._client.cache.clear()
+                await event.reply('cache cleared')
+
+
+PluginManager().register_plugin(GlmPlugin([MessageEvent]))
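
For a quick manual check of GlmClient outside the bot, a minimal sketch (assuming the repo root is on sys.path, the anon package is importable, and the ZhipuAI key above is valid; note that importing plugins.anon.glm also registers the plugin as a side effect):

    import asyncio

    from plugins.anon.glm import GlmClient


    async def main():
        # GlmClient must be constructed inside a running event loop:
        # its __init__ calls asyncio.get_running_loop().
        client = GlmClient()
        print(await client.generate_response(0, 'hello'))


    asyncio.run(main())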

requirements.txt (+1, -0)

@@ -0,0 +1 @@
+zhipuai~=2.0.1
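
This pin pulls in the ZhipuAI SDK that glm.py imports. Its chat.completions.create call is blocking, which is presumably why the plugin dispatches it through a ThreadPoolExecutor via run_in_executor instead of awaiting it directly.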

test.py (+12, -0)

@@ -0,0 +1,12 @@
+from anon import Bot
+from anon.logger import logger
+import logging
+
+logger.setLevel(logging.DEBUG)
+
+if __name__ == '__main__':
+    anon = Bot('192.168.5.15:5800', '114514')
+    anon.register_plugins([
+        'plugins.anon.glm'
+    ])
+    anon.loop()
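
To smoke-test the plugin, run python test.py with the endpoint above reachable, then send 'glm <prompt>' or 'glset help' to the bot (command shapes taken from on_event in glm.py).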