diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..0af19c5
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,13 @@
+[flake8]
+max-line-length = 176
+select = E303,W293,W291,W292,E305,E231,E302
+exclude =
+ .tox,
+ __pycache__,
+ *.pyc,
+ .env
+ venv/*
+ .venv/*
+ reports/*
+ dist/*
+ lib/*
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/1.bug.yml b/.github/ISSUE_TEMPLATE/1.bug.yml
new file mode 100644
index 0000000..2f762c0
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/1.bug.yml
@@ -0,0 +1,133 @@
+name: Bug report 🐛
+description: 项目运行中遇到的Bug或问题。
+labels: ['status: needs check']
+body:
+ - type: markdown
+ attributes:
+ value: |
+ ### ⚠️ 前置确认
+ 1. 网络能够访问openai接口
+ 2. python 已安装:版本在 3.7 ~ 3.10 之间
+ 3. `git pull` 拉取最新代码
+ 4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
+ 5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
+ 6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
+ - type: checkboxes
+ attributes:
+ label: 前置确认
+ options:
+ - label: 我确认我运行的是最新版本的代码,并且安装了所需的依赖,在[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)中也未找到类似问题。
+ required: true
+ - type: checkboxes
+ attributes:
+ label: ⚠️ 搜索issues中是否已存在类似问题
+ description: >
+ 请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索你的问题
+ 或相关日志的关键词来查找是否存在类似问题。
+ options:
+ - label: 我已经搜索过issues和disscussions,没有跟我遇到的问题相关的issue
+ required: true
+ - type: markdown
+ attributes:
+ value: |
+ 请在上方的`title`中填写你对你所遇到问题的简略总结,这将帮助其他人更好的找到相似问题,谢谢❤️。
+ - type: dropdown
+ attributes:
+ label: 操作系统类型?
+ description: >
+ 请选择你运行程序的操作系统类型。
+ options:
+ - Windows
+ - Linux
+ - MacOS
+ - Docker
+ - Railway
+ - Windows Subsystem for Linux (WSL)
+ - Other (请在问题中说明)
+ validations:
+ required: true
+ - type: dropdown
+ attributes:
+ label: 运行的python版本是?
+ description: |
+ 请选择你运行程序的`python`版本。
+ 注意:在`python 3.7`中,有部分可选依赖无法安装。
+ 经过长时间的观察,我们认为`python 3.8`是兼容性最好的版本。
+ `python 3.7`~`python 3.10`以外版本的issue,将视情况直接关闭。
+ options:
+ - python 3.7
+ - python 3.8
+ - python 3.9
+ - python 3.10
+ - other
+ validations:
+ required: true
+ - type: dropdown
+ attributes:
+ label: 使用的chatgpt-on-wechat版本是?
+ description: |
+ 请确保你使用的是 [releases](https://github.com/zhayujie/chatgpt-on-wechat/releases) 中的最新版本。
+ 如果你使用git, 请使用`git branch`命令来查看分支。
+ options:
+ - Latest Release
+ - Master (branch)
+ validations:
+ required: true
+ - type: dropdown
+ attributes:
+ label: 运行的`channel`类型是?
+ description: |
+ 请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
+ options:
+ - wx(个人微信, itchat)
+ - wxy(个人微信, wechaty)
+ - wechatmp(公众号, 订阅号)
+ - wechatmp_service(公众号, 服务号)
+ - terminal
+ - other
+ validations:
+ required: true
+ - type: textarea
+ attributes:
+ label: 复现步骤 🕹
+ description: |
+ **⚠️ 不能复现将会关闭issue.**
+ - type: textarea
+ attributes:
+ label: 问题描述 😯
+ description: 详细描述出现的问题,或提供有关截图。
+ - type: textarea
+ attributes:
+ label: 终端日志 📒
+ description: |
+ 在此处粘贴终端日志,可在主目录下`run.log`文件中找到,这会帮助我们更好的分析问题,注意隐去你的API key。
+ 如果在配置文件中加入`"debug": true`,打印出的日志会更有帮助。
+
+
+ 示例
+ ```log
+ [DEBUG][2023-04-16 00:23:22][plugin_manager.py:157] - Plugin SUMMARY triggered by event Event.ON_HANDLE_CONTEXT
+ [DEBUG][2023-04-16 00:23:22][main.py:221] - [Summary] on_handle_context. content: $总结前100条消息
+ [DEBUG][2023-04-16 00:23:24][main.py:240] - [Summary] limit: 100, duration: -1 seconds
+ [ERROR][2023-04-16 00:23:24][chat_channel.py:244] - Worker return exception: name 'start_date' is not defined
+ Traceback (most recent call last):
+ File "C:\ProgramData\Anaconda3\lib\concurrent\futures\thread.py", line 57, in run
+ result = self.fn(*self.args, **self.kwargs)
+ File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 132, in _handle
+ reply = self._generate_reply(context)
+ File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 142, in _generate_reply
+ e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
+ File "D:\project\chatgpt-on-wechat\plugins\plugin_manager.py", line 159, in emit_event
+ instance.handlers[e_context.event](e_context, *args, **kwargs)
+ File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 255, in on_handle_context
+ records = self._get_records(session_id, start_time, limit)
+ File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 96, in _get_records
+ c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_date, limit))
+ NameError: name 'start_date' is not defined
+ [INFO][2023-04-16 00:23:36][app.py:14] - signal 2 received, exiting...
+ ```
+
+ value: |
+ ```log
+ <此处粘贴终端日志>
+ ```
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/2.feature.yml b/.github/ISSUE_TEMPLATE/2.feature.yml
new file mode 100644
index 0000000..bbf0888
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/2.feature.yml
@@ -0,0 +1,28 @@
+name: Feature request 🚀
+description: 提出你对项目的新想法或建议。
+labels: ['status: needs check']
+body:
+ - type: markdown
+ attributes:
+ value: |
+ 请在上方的`title`中填写简略总结,谢谢❤️。
+ - type: checkboxes
+ attributes:
+ label: ⚠️ 搜索是否存在类似issue
+ description: >
+ 请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索关键词查找是否存在相似issue。
+ options:
+ - label: 我已经搜索过issues和disscussions,没有发现相似issue
+ required: true
+ - type: textarea
+ attributes:
+ label: 总结
+ description: 描述feature的功能。
+ - type: textarea
+ attributes:
+ label: 举例
+ description: 提供聊天示例,草图或相关网址。
+ - type: textarea
+ attributes:
+ label: 动机
+ description: 描述你提出该feature的动机,比如没有这项feature对你的使用造成了怎样的影响。 请提供更详细的场景描述,这可能会帮助我们发现并提出更好的解决方案。
\ No newline at end of file
diff --git a/.github/workflows/deploy-image-arm.yml b/.github/workflows/deploy-image-arm.yml
new file mode 100644
index 0000000..9721add
--- /dev/null
+++ b/.github/workflows/deploy-image-arm.yml
@@ -0,0 +1,72 @@
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+# GitHub recommends pinning actions to a commit SHA.
+# To get a newer version, you will need to update the SHA.
+# You can also reference a tag or branch, but the action may change without warning.
+
+name: Create and publish a Docker image
+
+on:
+ push:
+ branches: ['master']
+ create:
+env:
+ REGISTRY: ghcr.io
+ IMAGE_NAME: ${{ github.repository }}
+
+jobs:
+ build-and-push-image:
+ if: github.repository == 'zhayujie/chatgpt-on-wechat'
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ packages: write
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v3
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@v1
+
+ - name: Set up Docker Buildx
+ id: buildx
+ uses: docker/setup-buildx-action@v1
+
+ - name: Available platforms
+ run: echo ${{ steps.buildx.outputs.platforms }}
+
+ - name: Log in to the Container registry
+ uses: docker/login-action@v2
+ with:
+ registry: ${{ env.REGISTRY }}
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Extract metadata (tags, labels) for Docker
+ id: meta
+ uses: docker/metadata-action@v4
+ with:
+ images: |
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+
+ - name: Build and push Docker image
+ uses: docker/build-push-action@v3
+ with:
+ context: .
+ push: true
+ file: ./docker/Dockerfile.latest
+ platforms: linux/arm64
+ tags: ${{ steps.meta.outputs.tags }}-arm64
+ labels: ${{ steps.meta.outputs.labels }}
+
+ - uses: actions/delete-package-versions@v4
+ with:
+ package-name: 'chatgpt-on-wechat'
+ package-type: 'container'
+ min-versions-to-keep: 10
+ delete-only-untagged-versions: 'true'
+ token: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.github/workflows/deploy-image.yml b/.github/workflows/deploy-image.yml
new file mode 100644
index 0000000..a30b77f
--- /dev/null
+++ b/.github/workflows/deploy-image.yml
@@ -0,0 +1,68 @@
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+# GitHub recommends pinning actions to a commit SHA.
+# To get a newer version, you will need to update the SHA.
+# You can also reference a tag or branch, but the action may change without warning.
+
+name: Create and publish a Docker image
+
+on:
+ push:
+ branches: ['master']
+ create:
+env:
+ REGISTRY: ghcr.io
+ IMAGE_NAME: ${{ github.repository }}
+
+jobs:
+ build-and-push-image:
+ if: github.repository == 'zhayujie/chatgpt-on-wechat'
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ packages: write
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v3
+
+ - name: Login to Docker Hub
+ uses: docker/login-action@v2
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+ - name: Log in to the Container registry
+ uses: docker/login-action@v2
+ with:
+ registry: ${{ env.REGISTRY }}
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Extract metadata (tags, labels) for Docker
+ id: meta
+ uses: docker/metadata-action@v4
+ with:
+ images: |
+ ${{ env.IMAGE_NAME }}
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+
+ - name: Build and push Docker image
+ uses: docker/build-push-action@v3
+ with:
+ context: .
+ push: true
+ file: ./docker/Dockerfile.latest
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
+
+ - uses: actions/delete-package-versions@v4
+ with:
+ package-name: 'chatgpt-on-wechat'
+ package-type: 'container'
+ min-versions-to-keep: 10
+ delete-only-untagged-versions: 'true'
+ token: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..560e615
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,33 @@
+.DS_Store
+.idea
+.vscode
+.venv
+.vs
+.wechaty/
+__pycache__/
+venv*
+*.pyc
+config.json
+QR.png
+nohup.out
+tmp
+plugins.json
+itchat.pkl
+*.log
+user_datas.pkl
+chatgpt_tool_hub/
+plugins/**/
+!plugins/bdunit
+!plugins/dungeon
+!plugins/finish
+!plugins/godcmd
+!plugins/tool
+!plugins/banwords
+!plugins/banwords/**/
+plugins/banwords/__pycache__
+plugins/banwords/lib/__pycache__
+!plugins/hello
+!plugins/role
+!plugins/keyword
+!plugins/linkai
+client_config.json
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..5dd0d7d
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,30 @@
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.4.0
+ hooks:
+ - id: fix-byte-order-marker
+ - id: check-case-conflict
+ - id: check-merge-conflict
+ - id: debug-statements
+ - id: pretty-format-json
+ types: [text]
+ files: \.json(.template)?$
+ args: [ --autofix , --no-ensure-ascii, --indent=2, --no-sort-keys]
+ - id: trailing-whitespace
+ exclude: '(\/|^)lib\/'
+ args: [ --markdown-linebreak-ext=md ]
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ exclude: '(\/|^)lib\/'
+ - repo: https://github.com/psf/black
+ rev: 23.3.0
+ hooks:
+ - id: black
+ exclude: '(\/|^)lib\/'
+ - repo: https://github.com/PyCQA/flake8
+ rev: 6.0.0
+ hooks:
+ - id: flake8
+ exclude: '(\/|^)lib\/'
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..9d3bbb7
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,3 @@
+FROM ghcr.io/zhayujie/chatgpt-on-wechat:latest
+
+ENTRYPOINT ["/entrypoint.sh"]
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..fd03f7c
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2022 zhayujie
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..46541d0
--- /dev/null
+++ b/README.md
@@ -0,0 +1,281 @@
+# 简介
+
+> 本项目是基于大模型的智能对话机器人,支持微信、企业微信、公众号、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/文心一言/讯飞星火/通义千问/Gemini/LinkAI/ZhipuAI,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。
+
+最新版本支持的功能如下:
+
+- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信、微信公众号和、企业微信、飞书、钉钉等部署方式
+- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM
+- [x] **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
+- [x] **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型
+- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件
+- [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现
+
+# 演示
+
+https://github.com/zhayujie/chatgpt-on-wechat/assets/26161723/d5154020-36e3-41db-8706-40ce9f3f1b1e
+
+Demo made by [Visionn](https://www.wangpc.cc/)
+
+# 商业支持
+
+> 我们还提供企业级的 **AI应用平台**,包含知识库、Agent插件、应用管理等能力,支持多平台聚合的应用接入、客户端管理、对话管理,以及提供
+SaaS服务、私有化部署、稳定托管接入 等多种模式。
+>
+> 目前已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费等各行业沉淀了 AI 落地的最佳实践,致力于打造助力中小企业拥抱 AI 的一站式平台。
+
+企业服务和商用咨询可联系产品顾问:
+
+
+
+# 开源社区
+
+添加小助手微信加入开源项目交流群:
+
+
+
+# 更新日志
+
+>**2023.11.11:** [1.5.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.3) 和 [1.5.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.4),新增Google Gemini、通义千问模型
+
+>**2023.11.10:** [1.5.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.2),新增飞书通道、图像识别对话、黑名单配置
+
+>**2023.11.10:** [1.5.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.0),新增 `gpt-4-turbo`, `dall-e-3`, `tts` 模型接入,完善图像理解&生成、语音识别&生成的多模态能力
+
+>**2023.10.16:** 支持通过意图识别使用LinkAI联网搜索、数学计算、网页访问等插件,参考[插件文档](https://docs.link-ai.tech/platform/plugins)
+
+>**2023.09.26:** 插件增加 文件/文章链接 一键总结和对话的功能,使用参考:[插件说明](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai#3%E6%96%87%E6%A1%A3%E6%80%BB%E7%BB%93%E5%AF%B9%E8%AF%9D%E5%8A%9F%E8%83%BD)
+
+>**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图
+
+>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,并接入微信、公众号及企业微信中,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
+
+>**2023.04.26:** 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md)。(contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944))
+
+>**2023.04.05:** 支持微信公众号部署,兼容插件,并支持语音图片交互,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
+
+>**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
+
+>**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565))
+
+>**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
+
+>**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
+
+# 快速开始
+
+快速开始文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start)
+
+## 准备
+
+### 1. 账号注册
+
+项目默认使用OpenAI接口,需前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。接口需要海外网络访问及绑定信用卡支付。
+
+> 默认对话模型是 openai 的 gpt-3.5-turbo,计费方式是约每 1000tokens (约750个英文单词 或 500汉字,包含请求和回复) 消耗 $0.002,图片生成是Dell E模型,每张消耗 $0.016。
+
+项目同时也支持使用 LinkAI 接口,无需代理,可使用 文心、讯飞、GPT-3、GPT-4 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结和对话等能力。修改配置即可一键切换,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
+
+### 2.运行环境
+
+支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
+> 建议Python版本在 3.7.1~3.9.X 之间,推荐3.8版本,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
+
+> 注意:Docker 或 Railway 部署无需安装python环境和下载源码,可直接快进到下一节。
+
+**(1) 克隆项目代码:**
+
+```bash
+git clone https://github.com/zhayujie/chatgpt-on-wechat
+cd chatgpt-on-wechat/
+```
+
+注: 如遇到网络问题可选择国内镜像 https://gitee.com/zhayujie/chatgpt-on-wechat
+
+**(2) 安装核心依赖 (必选):**
+> 能够使用`itchat`创建机器人,并具有文字交流功能所需的最小依赖集合。
+```bash
+pip3 install -r requirements.txt
+```
+
+**(3) 拓展依赖 (可选,建议安装):**
+
+```bash
+pip3 install -r requirements-optional.txt
+```
+> 如果某项依赖安装失败可注释掉对应的行再继续
+
+## 配置
+
+配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
+
+```bash
+ cp config-template.json config.json
+```
+
+然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(请去掉注释):
+
+```bash
+# config.json文件内容示例
+{
+ "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
+ "model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
+ "proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
+ "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
+ "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
+ "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
+ "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
+ "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
+ "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
+ "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
+ "speech_recognition": false, # 是否开启语音识别
+ "group_speech_recognition": false, # 是否开启群组语音识别
+ "use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
+ "azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称
+ "azure_api_version": "", # 采用Azure ChatGPT时,API版本
+ "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
+ # 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
+ "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
+ "use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
+ "linkai_api_key": "", # LinkAI Api Key
+ "linkai_app_code": "" # LinkAI 应用code
+}
+```
+**配置说明:**
+
+**1.个人聊天**
+
++ 个人聊天中,需要以 "bot"或"@bot" 为开头的内容触发机器人,对应配置项 `single_chat_prefix` (如果不需要以前缀触发可以填写 `"single_chat_prefix": [""]`)
++ 机器人回复的内容会以 "[bot] " 作为前缀, 以区分真人,对应的配置项为 `single_chat_reply_prefix` (如果不需要前缀可以填写 `"single_chat_reply_prefix": ""`)
+
+**2.群组聊天**
+
++ 群组聊天中,群名称需配置在 `group_name_white_list ` 中才能开启群聊自动回复。如果想对所有群聊生效,可以直接填写 `"group_name_white_list": ["ALL_GROUP"]`
++ 默认只要被人 @ 就会触发机器人自动回复;另外群聊天中只要检测到以 "@bot" 开头的内容,同样会自动回复(方便自己触发),这对应配置项 `group_chat_prefix`
++ 可选配置: `group_name_keyword_white_list`配置项支持模糊匹配群名称,`group_chat_keyword`配置项则支持模糊匹配群消息内容,用法与上述两个配置项相同。(Contributed by [evolay](https://github.com/evolay))
++ `group_chat_in_one_session`:使群聊共享一个会话上下文,配置 `["ALL_GROUP"]` 则作用于所有群聊
+
+**3.语音识别**
+
++ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
++ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
++ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊),但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
+
+**4.其他配置**
+
++ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` , `claude` , `xunfei`(其中gpt-4 api暂未完全开放,申请通过后可使用)
++ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
++ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
++ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
++ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档,在[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中检查哪些参数在本项目中是可配置的。
++ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
++ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
++ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
++ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
++ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
++ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
+
+**5.LinkAI配置 (可选)**
+
++ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
++ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
++ `linkai_app_code`: LinkAI 应用code,选填
+
+**本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
+
+## 运行
+
+### 1.本地运行
+
+如果是开发机 **本地运行**,直接在项目根目录下执行:
+
+```bash
+python3 app.py # windows环境下该命令通常为 python app.py
+```
+
+终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
+
+### 2.服务器部署
+
+使用nohup命令在后台运行程序:
+
+```bash
+nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
+```
+扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
+
+> **多账号支持:** 将项目复制多份,分别启动程序,用不同账号扫码登录即可实现同时运行。
+
+> **特殊指令:** 用户向机器人发送 **#reset** 即可清空该用户的上下文记忆。
+
+
+### 3.Docker部署
+
+> 使用docker部署无需下载源码和安装依赖,只需要获取 docker-compose.yml 配置文件并启动容器即可。
+
+> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功的表现是执行 `docker -v` 和 `docker-compose version` (或 docker compose version) 可以查看到版本号,可前往 [docker官网](https://docs.docker.com/engine/install/) 进行下载。
+
+#### (1) 下载 docker-compose.yml 文件
+
+```bash
+wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
+```
+
+下载完成后打开 `docker-compose.yml` 修改所需配置,如 `OPEN_AI_API_KEY` 和 `GROUP_NAME_WHITE_LIST` 等。
+
+#### (2) 启动容器
+
+在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
+
+```bash
+sudo docker compose up -d
+```
+
+运行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。
+
+注意:
+
+ - 如果 `docker-compose` 是 1.X 版本 则需要执行 `sudo docker-compose up -d` 来启动容器
+ - 该命令会自动去 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取 latest 版本的镜像,latest 镜像会在每次项目 release 新的版本时生成
+
+最后运行以下命令可查看容器运行日志,扫描日志中的二维码即可完成登录:
+
+```bash
+sudo docker logs -f chatgpt-on-wechat
+```
+
+#### (3) 插件使用
+
+如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
+重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
+
+```
+volumes:
+ - ./config.json:/app/plugins/config.json
+```
+
+### 4. Railway部署
+
+> Railway 每月提供5刀和最多500小时的免费额度。 (07.11更新: 目前大部分账号已无法免费部署)
+
+1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)
+2. 点击 `Deploy Now` 按钮。
+3. 设置环境变量来重载程序运行的参数,例如`open_ai_api_key`, `character_desc`。
+
+**一键部署:**
+
+ [![Deploy on Railway](https://railway.app/button.svg)](https://railway.app/template/qApznZ?referralCode=RC3znh)
+
+## 常见问题
+
+FAQs:
+
+或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (beta版本,语料完善中,回复仅供参考)
+
+## 开发
+
+欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py) 实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
+
+## 联系
+
+欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品顾问](https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg)咨询。
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..ff2a6c7
--- /dev/null
+++ b/app.py
@@ -0,0 +1,71 @@
+# encoding:utf-8
+
+import os
+import signal
+import sys
+import time
+
+from channel import channel_factory
+from common import const
+from config import load_config
+from plugins import *
+import threading
+
+
+def sigterm_handler_wrap(_signo):
+ old_handler = signal.getsignal(_signo)
+
+ def func(_signo, _stack_frame):
+ logger.info("signal {} received, exiting...".format(_signo))
+ conf().save_user_datas()
+ if callable(old_handler): # check old_handler
+ return old_handler(_signo, _stack_frame)
+ sys.exit(0)
+
+ signal.signal(_signo, func)
+
+
+def start_channel(channel_name: str):
+ channel = channel_factory.create_channel(channel_name)
+ if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
+ const.FEISHU, const.DINGTALK]:
+ PluginManager().load_plugins()
+
+ if conf().get("use_linkai"):
+ try:
+ from common import linkai_client
+ threading.Thread(target=linkai_client.start, args=(channel,)).start()
+ except Exception as e:
+ pass
+ channel.startup()
+
+
+def run():
+ try:
+ # load config
+ load_config()
+ # ctrl + c
+ sigterm_handler_wrap(signal.SIGINT)
+ # kill signal
+ sigterm_handler_wrap(signal.SIGTERM)
+
+ # create channel
+ channel_name = conf().get("channel_type", "wx")
+
+ if "--cmd" in sys.argv:
+ channel_name = "terminal"
+
+ if channel_name == "wxy":
+ os.environ["WECHATY_LOG"] = "warn"
+
+ start_channel(channel_name)
+
+ while True:
+ time.sleep(1)
+ except Exception as e:
+ logger.error("App startup failed!")
+ logger.exception(e)
+
+
+if __name__ == "__main__":
+ run()
diff --git a/bot/ali/ali_qwen_bot.py b/bot/ali/ali_qwen_bot.py
new file mode 100644
index 0000000..ae9d767
--- /dev/null
+++ b/bot/ali/ali_qwen_bot.py
@@ -0,0 +1,214 @@
+# encoding:utf-8
+
+import json
+import time
+from typing import List, Tuple
+
+import openai
+import openai.error
+import broadscope_bailian
+from broadscope_bailian import ChatQaMessage
+
+from bot.bot import Bot
+from bot.ali.ali_qwen_session import AliQwenSession
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common import const
+from config import conf, load_config
+
+class AliQwenBot(Bot):
+ def __init__(self):
+ super().__init__()
+ self.api_key_expired_time = self.set_api_key()
+ self.sessions = SessionManager(AliQwenSession, model=conf().get("model", const.QWEN))
+
+ def api_key_client(self):
+ return broadscope_bailian.AccessTokenClient(access_key_id=self.access_key_id(), access_key_secret=self.access_key_secret())
+
+ def access_key_id(self):
+ return conf().get("qwen_access_key_id")
+
+ def access_key_secret(self):
+ return conf().get("qwen_access_key_secret")
+
+ def agent_key(self):
+ return conf().get("qwen_agent_key")
+
+ def app_id(self):
+ return conf().get("qwen_app_id")
+
+ def node_id(self):
+ return conf().get("qwen_node_id", "")
+
+ def temperature(self):
+ return conf().get("temperature", 0.2 )
+
+ def top_p(self):
+ return conf().get("top_p", 1)
+
+ def reply(self, query, context=None):
+ # acquire reply content
+ if context.type == ContextType.TEXT:
+ logger.info("[QWEN] query={}".format(query))
+
+ session_id = context["session_id"]
+ reply = None
+ clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
+ if query in clear_memory_commands:
+ self.sessions.clear_session(session_id)
+ reply = Reply(ReplyType.INFO, "记忆已清除")
+ elif query == "#清除所有":
+ self.sessions.clear_all_session()
+ reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+ elif query == "#更新配置":
+ load_config()
+ reply = Reply(ReplyType.INFO, "配置已更新")
+ if reply:
+ return reply
+ session = self.sessions.session_query(query, session_id)
+ logger.debug("[QWEN] session query={}".format(session.messages))
+
+ reply_content = self.reply_text(session)
+ logger.debug(
+ "[QWEN] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+ session.messages,
+ session_id,
+ reply_content["content"],
+ reply_content["completion_tokens"],
+ )
+ )
+ if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
+ elif reply_content["completion_tokens"] > 0:
+ self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+ reply = Reply(ReplyType.TEXT, reply_content["content"])
+ else:
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
+ logger.debug("[QWEN] reply {} used 0 tokens.".format(reply_content))
+ return reply
+
+ else:
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+ return reply
+
+ def reply_text(self, session: AliQwenSession, retry_count=0) -> dict:
+ """
+ call bailian's ChatCompletion to get the answer
+ :param session: a conversation session
+ :param retry_count: retry count
+ :return: {}
+ """
+ try:
+ prompt, history = self.convert_messages_format(session.messages)
+ self.update_api_key_if_expired()
+ # NOTE 阿里百炼的call()函数未提供temperature参数,考虑到temperature和top_p参数作用相同,取两者较小的值作为top_p参数传入,详情见文档 https://help.aliyun.com/document_detail/2587502.htm
+ response = broadscope_bailian.Completions().call(app_id=self.app_id(), prompt=prompt, history=history, top_p=min(self.temperature(), self.top_p()))
+ completion_content = self.get_completion_content(response, self.node_id())
+ completion_tokens, total_tokens = self.calc_tokens(session.messages, completion_content)
+ return {
+ "total_tokens": total_tokens,
+ "completion_tokens": completion_tokens,
+ "content": completion_content,
+ }
+ except Exception as e:
+ need_retry = retry_count < 2
+ result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+ if isinstance(e, openai.error.RateLimitError):
+ logger.warn("[QWEN] RateLimitError: {}".format(e))
+ result["content"] = "提问太快啦,请休息一下再问我吧"
+ if need_retry:
+ time.sleep(20)
+ elif isinstance(e, openai.error.Timeout):
+ logger.warn("[QWEN] Timeout: {}".format(e))
+ result["content"] = "我没有收到你的消息"
+ if need_retry:
+ time.sleep(5)
+ elif isinstance(e, openai.error.APIError):
+ logger.warn("[QWEN] Bad Gateway: {}".format(e))
+ result["content"] = "请再问我一次"
+ if need_retry:
+ time.sleep(10)
+ elif isinstance(e, openai.error.APIConnectionError):
+ logger.warn("[QWEN] APIConnectionError: {}".format(e))
+ need_retry = False
+ result["content"] = "我连接不到你的网络"
+ else:
+ logger.exception("[QWEN] Exception: {}".format(e))
+ need_retry = False
+ self.sessions.clear_session(session.session_id)
+
+ if need_retry:
+ logger.warn("[QWEN] 第{}次重试".format(retry_count + 1))
+ return self.reply_text(session, retry_count + 1)
+ else:
+ return result
+
+ def set_api_key(self):
+ api_key, expired_time = self.api_key_client().create_token(agent_key=self.agent_key())
+ broadscope_bailian.api_key = api_key
+ return expired_time
+
+ def update_api_key_if_expired(self):
+ if time.time() > self.api_key_expired_time:
+ self.api_key_expired_time = self.set_api_key()
+
+ def convert_messages_format(self, messages) -> Tuple[str, List[ChatQaMessage]]:
+ history = []
+ user_content = ''
+ assistant_content = ''
+ system_content = ''
+ for message in messages:
+ role = message.get('role')
+ if role == 'user':
+ user_content += message.get('content')
+ elif role == 'assistant':
+ assistant_content = message.get('content')
+ history.append(ChatQaMessage(user_content, assistant_content))
+ user_content = ''
+ assistant_content = ''
+ elif role =='system':
+ system_content += message.get('content')
+ if user_content == '':
+ raise Exception('no user message')
+ if system_content != '':
+ # NOTE 模拟系统消息,测试发现人格描述以"你需要扮演ChatGPT"开头能够起作用,而以"你是ChatGPT"开头模型会直接否认
+ system_qa = ChatQaMessage(system_content, '好的,我会严格按照你的设定回答问题')
+ history.insert(0, system_qa)
+ logger.debug("[QWEN] converted qa messages: {}".format([item.to_dict() for item in history]))
+ logger.debug("[QWEN] user content as prompt: {}".format(user_content))
+ return user_content, history
+
+ def get_completion_content(self, response, node_id):
+ if not response['Success']:
+ return f"[ERROR]\n{response['Code']}:{response['Message']}"
+ text = response['Data']['Text']
+ if node_id == '':
+ return text
+ # TODO: 当使用流程编排创建大模型应用时,响应结构如下,最终结果在['finalResult'][node_id]['response']['text']中,暂时先这么写
+ # {
+ # 'Success': True,
+ # 'Code': None,
+ # 'Message': None,
+ # 'Data': {
+ # 'ResponseId': '9822f38dbacf4c9b8daf5ca03a2daf15',
+ # 'SessionId': 'session_id',
+ # 'Text': '{"finalResult":{"LLM_T7islK":{"params":{"modelId":"qwen-plus-v1","prompt":"${systemVars.query}${bizVars.Text}"},"response":{"text":"作为一个AI语言模型,我没有年龄,因为我没有生日。\n我只是一个程序,没有生命和身体。"}}}}',
+ # 'Thoughts': [],
+ # 'Debug': {},
+ # 'DocReferences': []
+ # },
+ # 'RequestId': '8e11d31551ce4c3f83f49e6e0dd998b0',
+ # 'Failed': None
+ # }
+ text_dict = json.loads(text)
+ completion_content = text_dict['finalResult'][node_id]['response']['text']
+ return completion_content
+
+ def calc_tokens(self, messages, completion_content):
+ completion_tokens = len(completion_content)
+ prompt_tokens = 0
+ for message in messages:
+ prompt_tokens += len(message["content"])
+ return completion_tokens, prompt_tokens + completion_tokens
diff --git a/bot/ali/ali_qwen_session.py b/bot/ali/ali_qwen_session.py
new file mode 100644
index 0000000..0eb1c4a
--- /dev/null
+++ b/bot/ali/ali_qwen_session.py
@@ -0,0 +1,62 @@
+from bot.session_manager import Session
+from common.log import logger
+
+"""
+ e.g.
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Who won the world series in 2020?"},
+ {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
+ {"role": "user", "content": "Where was it played?"}
+ ]
+"""
+
+class AliQwenSession(Session):
+ def __init__(self, session_id, system_prompt=None, model="qianwen"):
+ super().__init__(session_id, system_prompt)
+ self.model = model
+ self.reset()
+
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
+ precise = True
+ try:
+ cur_tokens = self.calc_tokens()
+ except Exception as e:
+ precise = False
+ if cur_tokens is None:
+ raise e
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+ while cur_tokens > max_tokens:
+ if len(self.messages) > 2:
+ self.messages.pop(1)
+ elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+ self.messages.pop(1)
+ if precise:
+ cur_tokens = self.calc_tokens()
+ else:
+ cur_tokens = cur_tokens - max_tokens
+ break
+ elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+ logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+ break
+ else:
+ logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
+ break
+ if precise:
+ cur_tokens = self.calc_tokens()
+ else:
+ cur_tokens = cur_tokens - max_tokens
+ return cur_tokens
+
+ def calc_tokens(self):
+ return num_tokens_from_messages(self.messages, self.model)
+
+def num_tokens_from_messages(messages, model):
+ """Returns the number of tokens used by a list of messages."""
+ # 官方token计算规则:"对于中文文本来说,1个token通常对应一个汉字;对于英文文本来说,1个token通常对应3至4个字母或1个单词"
+ # 详情请产看文档:https://help.aliyun.com/document_detail/2586397.html
+ # 目前根据字符串长度粗略估计token数,不影响正常使用
+ tokens = 0
+ for msg in messages:
+ tokens += len(msg["content"])
+ return tokens
diff --git a/bot/baidu/baidu_unit_bot.py b/bot/baidu/baidu_unit_bot.py
new file mode 100644
index 0000000..f7714e4
--- /dev/null
+++ b/bot/baidu/baidu_unit_bot.py
@@ -0,0 +1,36 @@
+# encoding:utf-8
+
+import requests
+
+from bot.bot import Bot
+from bridge.reply import Reply, ReplyType
+
+
+# Baidu Unit对话接口 (可用, 但能力较弱)
+class BaiduUnitBot(Bot):
+ def reply(self, query, context=None):
+ token = self.get_token()
+ url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
+ post_data = (
+ '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+ + query
+ + '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
+ )
+ print(post_data)
+ headers = {"content-type": "application/x-www-form-urlencoded"}
+ response = requests.post(url, data=post_data.encode(), headers=headers)
+ if response:
+ reply = Reply(
+ ReplyType.TEXT,
+ response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
+ )
+ return reply
+
+ def get_token(self):
+ access_key = "YOUR_ACCESS_KEY"
+ secret_key = "YOUR_SECRET_KEY"
+ host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
+ response = requests.get(host)
+ if response:
+ print(response.json())
+ return response.json()["access_token"]
diff --git a/bot/baidu/baidu_wenxin.py b/bot/baidu/baidu_wenxin.py
new file mode 100644
index 0000000..f35e0fa
--- /dev/null
+++ b/bot/baidu/baidu_wenxin.py
@@ -0,0 +1,107 @@
+# encoding:utf-8
+
+import requests, json
+from bot.bot import Bot
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
+
+BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
+BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
+
+class BaiduWenxinBot(Bot):
+
+ def __init__(self):
+ super().__init__()
+ wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant"
+ if conf().get("model") and conf().get("model") == "wenxin-4":
+ wenxin_model = "completions_pro"
+ self.sessions = SessionManager(BaiduWenxinSession, model=wenxin_model)
+
+ def reply(self, query, context=None):
+ # acquire reply content
+ if context and context.type:
+ if context.type == ContextType.TEXT:
+ logger.info("[BAIDU] query={}".format(query))
+ session_id = context["session_id"]
+ reply = None
+ if query == "#清除记忆":
+ self.sessions.clear_session(session_id)
+ reply = Reply(ReplyType.INFO, "记忆已清除")
+ elif query == "#清除所有":
+ self.sessions.clear_all_session()
+ reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+ else:
+ session = self.sessions.session_query(query, session_id)
+ result = self.reply_text(session)
+ total_tokens, completion_tokens, reply_content = (
+ result["total_tokens"],
+ result["completion_tokens"],
+ result["content"],
+ )
+ logger.debug(
+ "[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
+ )
+
+ if total_tokens == 0:
+ reply = Reply(ReplyType.ERROR, reply_content)
+ else:
+ self.sessions.session_reply(reply_content, session_id, total_tokens)
+ reply = Reply(ReplyType.TEXT, reply_content)
+ return reply
+ elif context.type == ContextType.IMAGE_CREATE:
+ ok, retstring = self.create_img(query, 0)
+ reply = None
+ if ok:
+ reply = Reply(ReplyType.IMAGE_URL, retstring)
+ else:
+ reply = Reply(ReplyType.ERROR, retstring)
+ return reply
+
+ def reply_text(self, session: BaiduWenxinSession, retry_count=0):
+ try:
+ logger.info("[BAIDU] model={}".format(session.model))
+ access_token = self.get_access_token()
+ if access_token == 'None':
+ logger.warn("[BAIDU] access token 获取失败")
+ return {
+ "total_tokens": 0,
+ "completion_tokens": 0,
+ "content": 0,
+ }
+ url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
+ headers = {
+ 'Content-Type': 'application/json'
+ }
+ payload = {'messages': session.messages}
+ response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
+ response_text = json.loads(response.text)
+ logger.info(f"[BAIDU] response text={response_text}")
+ res_content = response_text["result"]
+ total_tokens = response_text["usage"]["total_tokens"]
+ completion_tokens = response_text["usage"]["completion_tokens"]
+ logger.info("[BAIDU] reply={}".format(res_content))
+ return {
+ "total_tokens": total_tokens,
+ "completion_tokens": completion_tokens,
+ "content": res_content,
+ }
+ except Exception as e:
+ need_retry = retry_count < 2
+ logger.warn("[BAIDU] Exception: {}".format(e))
+ need_retry = False
+ self.sessions.clear_session(session.session_id)
+ result = {"completion_tokens": 0, "content": "出错了: {}".format(e)}
+ return result
+
+ def get_access_token(self):
+ """
+ 使用 AK,SK 生成鉴权签名(Access Token)
+ :return: access_token,或是None(如果错误)
+ """
+ url = "https://aip.baidubce.com/oauth/2.0/token"
+ params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
+ return str(requests.post(url, params=params).json().get("access_token"))
diff --git a/bot/baidu/baidu_wenxin_session.py b/bot/baidu/baidu_wenxin_session.py
new file mode 100644
index 0000000..5ba2f17
--- /dev/null
+++ b/bot/baidu/baidu_wenxin_session.py
@@ -0,0 +1,53 @@
+from bot.session_manager import Session
+from common.log import logger
+
+"""
+ e.g. [
+ {"role": "user", "content": "Who won the world series in 2020?"},
+ {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
+ {"role": "user", "content": "Where was it played?"}
+ ]
+"""
+
+
+class BaiduWenxinSession(Session):
+ def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
+ super().__init__(session_id, system_prompt)
+ self.model = model
+ # 百度文心不支持system prompt
+ # self.reset()
+
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
+ precise = True
+ try:
+ cur_tokens = self.calc_tokens()
+ except Exception as e:
+ precise = False
+ if cur_tokens is None:
+ raise e
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+ while cur_tokens > max_tokens:
+ if len(self.messages) >= 2:
+ self.messages.pop(0)
+ self.messages.pop(0)
+ else:
+ logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
+ break
+ if precise:
+ cur_tokens = self.calc_tokens()
+ else:
+ cur_tokens = cur_tokens - max_tokens
+ return cur_tokens
+
+ def calc_tokens(self):
+ return num_tokens_from_messages(self.messages, self.model)
+
+
+def num_tokens_from_messages(messages, model):
+ """Returns the number of tokens used by a list of messages."""
+ tokens = 0
+ for msg in messages:
+ # 官方token计算规则暂不明确: "大约为 token数为 "中文字 + 其他语种单词数 x 1.3"
+ # 这里先直接根据字数粗略估算吧,暂不影响正常使用,仅在判断是否丢弃历史会话的时候会有偏差
+ tokens += len(msg["content"])
+ return tokens
diff --git a/bot/bot.py b/bot/bot.py
new file mode 100644
index 0000000..ca6e1aa
--- /dev/null
+++ b/bot/bot.py
@@ -0,0 +1,17 @@
+"""
+Auto-replay chat robot abstract class
+"""
+
+
+from bridge.context import Context
+from bridge.reply import Reply
+
+
+class Bot(object):
+ def reply(self, query, context: Context = None) -> Reply:
+ """
+ bot auto-reply content
+ :param req: received message
+ :return: reply content
+ """
+ raise NotImplementedError
diff --git a/bot/bot_factory.py b/bot/bot_factory.py
new file mode 100644
index 0000000..2046da7
--- /dev/null
+++ b/bot/bot_factory.py
@@ -0,0 +1,60 @@
+"""
+channel factory
+"""
+from common import const
+
+
+def create_bot(bot_type):
+ """
+ create a bot_type instance
+ :param bot_type: bot type code
+ :return: bot instance
+ """
+ if bot_type == const.BAIDU:
+ # 替换Baidu Unit为Baidu文心千帆对话接口
+ # from bot.baidu.baidu_unit_bot import BaiduUnitBot
+ # return BaiduUnitBot()
+ from bot.baidu.baidu_wenxin import BaiduWenxinBot
+ return BaiduWenxinBot()
+
+ elif bot_type == const.CHATGPT:
+ # ChatGPT 网页端web接口
+ from bot.chatgpt.chat_gpt_bot import ChatGPTBot
+ return ChatGPTBot()
+
+ elif bot_type == const.OPEN_AI:
+ # OpenAI 官方对话模型API
+ from bot.openai.open_ai_bot import OpenAIBot
+ return OpenAIBot()
+
+ elif bot_type == const.CHATGPTONAZURE:
+ # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
+ from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
+ return AzureChatGPTBot()
+
+ elif bot_type == const.XUNFEI:
+ from bot.xunfei.xunfei_spark_bot import XunFeiBot
+ return XunFeiBot()
+
+ elif bot_type == const.LINKAI:
+ from bot.linkai.link_ai_bot import LinkAIBot
+ return LinkAIBot()
+
+ elif bot_type == const.CLAUDEAI:
+ from bot.claude.claude_ai_bot import ClaudeAIBot
+ return ClaudeAIBot()
+
+ elif bot_type == const.QWEN:
+ from bot.ali.ali_qwen_bot import AliQwenBot
+ return AliQwenBot()
+
+ elif bot_type == const.GEMINI:
+ from bot.gemini.google_gemini_bot import GoogleGeminiBot
+ return GoogleGeminiBot()
+
+ elif bot_type == const.ZHIPU_AI:
+ from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
+ return ZHIPUAIBot()
+
+
+ raise RuntimeError
diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py
new file mode 100644
index 0000000..979ce4c
--- /dev/null
+++ b/bot/chatgpt/chat_gpt_bot.py
@@ -0,0 +1,194 @@
+# encoding:utf-8
+
+import time
+
+import openai
+import openai.error
+import requests
+
+from bot.bot import Bot
+from bot.chatgpt.chat_gpt_session import ChatGPTSession
+from bot.openai.open_ai_image import OpenAIImage
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common.token_bucket import TokenBucket
+from config import conf, load_config
+
+
+# OpenAI对话模型API (可用)
+class ChatGPTBot(Bot, OpenAIImage):
+ def __init__(self):
+ super().__init__()
+ # set the default api_key
+ openai.api_key = conf().get("open_ai_api_key")
+ if conf().get("open_ai_api_base"):
+ openai.api_base = conf().get("open_ai_api_base")
+ proxy = conf().get("proxy")
+ if proxy:
+ openai.proxy = proxy
+ if conf().get("rate_limit_chatgpt"):
+ self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
+
+ self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
+ self.args = {
+ "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
+ "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
+ # "max_tokens":4096, # 回复最大的字符数
+ "top_p": conf().get("top_p", 1),
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+ "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
+ }
+
+ def reply(self, query, context=None):
+ # acquire reply content
+ if context.type == ContextType.TEXT:
+ logger.info("[CHATGPT] query={}".format(query))
+
+ session_id = context["session_id"]
+ reply = None
+ clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
+ if query in clear_memory_commands:
+ self.sessions.clear_session(session_id)
+ reply = Reply(ReplyType.INFO, "记忆已清除")
+ elif query == "#清除所有":
+ self.sessions.clear_all_session()
+ reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+ elif query == "#更新配置":
+ load_config()
+ reply = Reply(ReplyType.INFO, "配置已更新")
+ if reply:
+ return reply
+ session = self.sessions.session_query(query, session_id)
+ logger.debug("[CHATGPT] session query={}".format(session.messages))
+
+ api_key = context.get("openai_api_key")
+ model = context.get("gpt_model")
+ new_args = None
+ if model:
+ new_args = self.args.copy()
+ new_args["model"] = model
+ # if context.get('stream'):
+ # # reply in stream
+ # return self.reply_text_stream(query, new_query, session_id)
+
+ reply_content = self.reply_text(session, api_key, args=new_args)
+ logger.debug(
+ "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+ session.messages,
+ session_id,
+ reply_content["content"],
+ reply_content["completion_tokens"],
+ )
+ )
+ if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
+ elif reply_content["completion_tokens"] > 0:
+ self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+ reply = Reply(ReplyType.TEXT, reply_content["content"])
+ else:
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
+ logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
+ return reply
+
+ elif context.type == ContextType.IMAGE_CREATE:
+ ok, retstring = self.create_img(query, 0)
+ reply = None
+ if ok:
+ reply = Reply(ReplyType.IMAGE_URL, retstring)
+ else:
+ reply = Reply(ReplyType.ERROR, retstring)
+ return reply
+ else:
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+ return reply
+
+ def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
+ """
+ call openai's ChatCompletion to get the answer
+ :param session: a conversation session
+ :param session_id: session id
+ :param retry_count: retry count
+ :return: {}
+ """
+ try:
+ if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
+ raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
+ # if api_key == None, the default openai.api_key will be used
+ if args is None:
+ args = self.args
+ response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
+ # logger.debug("[CHATGPT] response={}".format(response))
+ # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
+ return {
+ "total_tokens": response["usage"]["total_tokens"],
+ "completion_tokens": response["usage"]["completion_tokens"],
+ "content": response.choices[0]["message"]["content"],
+ }
+ except Exception as e:
+ need_retry = retry_count < 2
+ result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+ if isinstance(e, openai.error.RateLimitError):
+ logger.warn("[CHATGPT] RateLimitError: {}".format(e))
+ result["content"] = "提问太快啦,请休息一下再问我吧"
+ if need_retry:
+ time.sleep(20)
+ elif isinstance(e, openai.error.Timeout):
+ logger.warn("[CHATGPT] Timeout: {}".format(e))
+ result["content"] = "我没有收到你的消息"
+ if need_retry:
+ time.sleep(5)
+ elif isinstance(e, openai.error.APIError):
+ logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
+ result["content"] = "请再问我一次"
+ if need_retry:
+ time.sleep(10)
+ elif isinstance(e, openai.error.APIConnectionError):
+ logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
+ result["content"] = "我连接不到你的网络"
+ if need_retry:
+ time.sleep(5)
+ else:
+ logger.exception("[CHATGPT] Exception: {}".format(e))
+ need_retry = False
+ self.sessions.clear_session(session.session_id)
+
+ if need_retry:
+ logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
+ return self.reply_text(session, api_key, args, retry_count + 1)
+ else:
+ return result
+
+
+class AzureChatGPTBot(ChatGPTBot):
+ def __init__(self):
+ super().__init__()
+ openai.api_type = "azure"
+ openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
+ self.args["deployment_id"] = conf().get("azure_deployment_id")
+
+ def create_img(self, query, retry_count=0, api_key=None):
+ api_version = "2022-08-03-preview"
+ url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version)
+ api_key = api_key or openai.api_key
+ headers = {"api-key": api_key, "Content-Type": "application/json"}
+ try:
+ body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
+ submission = requests.post(url, headers=headers, json=body)
+ operation_location = submission.headers["Operation-Location"]
+ retry_after = submission.headers["Retry-after"]
+ status = ""
+ image_url = ""
+ while status != "Succeeded":
+ logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
+ time.sleep(int(retry_after))
+ response = requests.get(operation_location, headers=headers)
+ status = response.json()["status"]
+ image_url = response.json()["result"]["contentUrl"]
+ return True, image_url
+ except Exception as e:
+ logger.error("create image error: {}".format(e))
+ return False, "图片生成失败"
diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py
new file mode 100644
index 0000000..c4b17fd
--- /dev/null
+++ b/bot/chatgpt/chat_gpt_session.py
@@ -0,0 +1,102 @@
+from bot.session_manager import Session
+from common.log import logger
+from common import const
+
+"""
+ e.g. [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Who won the world series in 2020?"},
+ {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
+ {"role": "user", "content": "Where was it played?"}
+ ]
+"""
+
+
+class ChatGPTSession(Session):
+ def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
+ super().__init__(session_id, system_prompt)
+ self.model = model
+ self.reset()
+
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
+ precise = True
+ try:
+ cur_tokens = self.calc_tokens()
+ except Exception as e:
+ precise = False
+ if cur_tokens is None:
+ raise e
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+ while cur_tokens > max_tokens:
+ if len(self.messages) > 2:
+ self.messages.pop(1)
+ elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+ self.messages.pop(1)
+ if precise:
+ cur_tokens = self.calc_tokens()
+ else:
+ cur_tokens = cur_tokens - max_tokens
+ break
+ elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+ logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+ break
+ else:
+ logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
+ break
+ if precise:
+ cur_tokens = self.calc_tokens()
+ else:
+ cur_tokens = cur_tokens - max_tokens
+ return cur_tokens
+
+ def calc_tokens(self):
+ return num_tokens_from_messages(self.messages, self.model)
+
+
+# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+def num_tokens_from_messages(messages, model):
+ """Returns the number of tokens used by a list of messages."""
+
+ if model in ["wenxin", "xunfei", const.GEMINI]:
+ return num_tokens_by_character(messages)
+
+ import tiktoken
+
+ if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]:
+ return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
+ elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
+ "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
+ "gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]:
+ return num_tokens_from_messages(messages, model="gpt-4")
+
+ try:
+ encoding = tiktoken.encoding_for_model(model)
+ except KeyError:
+ logger.debug("Warning: model not found. Using cl100k_base encoding.")
+ encoding = tiktoken.get_encoding("cl100k_base")
+ if model == "gpt-3.5-turbo":
+ tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
+ tokens_per_name = -1 # if there's a name, the role is omitted
+ elif model == "gpt-4":
+ tokens_per_message = 3
+ tokens_per_name = 1
+ else:
+ logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
+ return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
+ num_tokens = 0
+ for message in messages:
+ num_tokens += tokens_per_message
+ for key, value in message.items():
+ num_tokens += len(encoding.encode(value))
+ if key == "name":
+ num_tokens += tokens_per_name
+ num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
+ return num_tokens
+
+
+def num_tokens_by_character(messages):
+ """Returns the number of tokens used by a list of messages."""
+ tokens = 0
+ for msg in messages:
+ tokens += len(msg["content"])
+ return tokens
diff --git a/bot/claude/claude_ai_bot.py b/bot/claude/claude_ai_bot.py
new file mode 100644
index 0000000..faad274
--- /dev/null
+++ b/bot/claude/claude_ai_bot.py
@@ -0,0 +1,222 @@
+import re
+import time
+import json
+import uuid
+from curl_cffi import requests
+from bot.bot import Bot
+from bot.claude.claude_ai_session import ClaudeAiSession
+from bot.openai.open_ai_image import OpenAIImage
+from bot.session_manager import SessionManager
+from bridge.context import Context, ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+
+
+class ClaudeAIBot(Bot, OpenAIImage):
+ def __init__(self):
+ super().__init__()
+ self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
+ self.claude_api_cookie = conf().get("claude_api_cookie")
+ self.proxy = conf().get("proxy")
+ self.con_uuid_dic = {}
+ if self.proxy:
+ self.proxies = {
+ "http": self.proxy,
+ "https": self.proxy
+ }
+ else:
+ self.proxies = None
+ self.error = ""
+ self.org_uuid = self.get_organization_id()
+
+ def generate_uuid(self):
+ random_uuid = uuid.uuid4()
+ random_uuid_str = str(random_uuid)
+ formatted_uuid = f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}"
+ return formatted_uuid
+
+ def reply(self, query, context: Context = None) -> Reply:
+ if context.type == ContextType.TEXT:
+ return self._chat(query, context)
+ elif context.type == ContextType.IMAGE_CREATE:
+ ok, res = self.create_img(query, 0)
+ if ok:
+ reply = Reply(ReplyType.IMAGE_URL, res)
+ else:
+ reply = Reply(ReplyType.ERROR, res)
+ return reply
+ else:
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+ return reply
+
+ def get_organization_id(self):
+ url = "https://claude.ai/api/organizations"
+ headers = {
+ 'User-Agent':
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
+ 'Accept-Language': 'en-US,en;q=0.5',
+ 'Referer': 'https://claude.ai/chats',
+ 'Content-Type': 'application/json',
+ 'Sec-Fetch-Dest': 'empty',
+ 'Sec-Fetch-Mode': 'cors',
+ 'Sec-Fetch-Site': 'same-origin',
+ 'Connection': 'keep-alive',
+ 'Cookie': f'{self.claude_api_cookie}'
+ }
+ try:
+ response = requests.get(url, headers=headers, impersonate="chrome110", proxies =self.proxies, timeout=400)
+ res = json.loads(response.text)
+ uuid = res[0]['uuid']
+ except:
+ if "App unavailable" in response.text:
+ logger.error("IP error: The IP is not allowed to be used on Claude")
+ self.error = "ip所在地区不被claude支持"
+ elif "Invalid authorization" in response.text:
+ logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
+ self.error = "无法通过claude身份验证,请检查cookie"
+ return None
+ return uuid
+
+ def conversation_share_check(self,session_id):
+ if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
+ con_uuid = conf().get("claude_uuid")
+ return con_uuid
+ if session_id not in self.con_uuid_dic:
+ self.con_uuid_dic[session_id] = self.generate_uuid()
+ self.create_new_chat(self.con_uuid_dic[session_id])
+ return self.con_uuid_dic[session_id]
+
+ def check_cookie(self):
+ flag = self.get_organization_id()
+ return flag
+
+ def create_new_chat(self, con_uuid):
+ """
+ 新建claude对话实体
+ :param con_uuid: 对话id
+ :return:
+ """
+ url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
+ payload = json.dumps({"uuid": con_uuid, "name": ""})
+ headers = {
+ 'User-Agent':
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
+ 'Accept-Language': 'en-US,en;q=0.5',
+ 'Referer': 'https://claude.ai/chats',
+ 'Content-Type': 'application/json',
+ 'Origin': 'https://claude.ai',
+ 'DNT': '1',
+ 'Connection': 'keep-alive',
+ 'Cookie': self.claude_api_cookie,
+ 'Sec-Fetch-Dest': 'empty',
+ 'Sec-Fetch-Mode': 'cors',
+ 'Sec-Fetch-Site': 'same-origin',
+ 'TE': 'trailers'
+ }
+ response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
+ # Returns JSON of the newly created conversation information
+ return response.json()
+
+ def _chat(self, query, context, retry_count=0) -> Reply:
+ """
+ 发起对话请求
+ :param query: 请求提示词
+ :param context: 对话上下文
+ :param retry_count: 当前递归重试次数
+ :return: 回复
+ """
+ if retry_count >= 2:
+ # exit from retry 2 times
+ logger.warn("[CLAUDEAI] failed after maximum number of retry times")
+ return Reply(ReplyType.ERROR, "请再问我一次吧")
+
+ try:
+ session_id = context["session_id"]
+ if self.org_uuid is None:
+ return Reply(ReplyType.ERROR, self.error)
+
+ session = self.sessions.session_query(query, session_id)
+ con_uuid = self.conversation_share_check(session_id)
+
+ model = conf().get("model") or "gpt-3.5-turbo"
+ # remove system message
+ if session.messages[0].get("role") == "system":
+ if model == "wenxin" or model == "claude":
+ session.messages.pop(0)
+ logger.info(f"[CLAUDEAI] query={query}")
+
+ # do http request
+ base_url = "https://claude.ai"
+ payload = json.dumps({
+ "completion": {
+ "prompt": f"{query}",
+ "timezone": "Asia/Kolkata",
+ "model": "claude-2"
+ },
+ "organization_uuid": f"{self.org_uuid}",
+ "conversation_uuid": f"{con_uuid}",
+ "text": f"{query}",
+ "attachments": []
+ })
+ headers = {
+ 'User-Agent':
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
+ 'Accept': 'text/event-stream, text/event-stream',
+ 'Accept-Language': 'en-US,en;q=0.5',
+ 'Referer': 'https://claude.ai/chats',
+ 'Content-Type': 'application/json',
+ 'Origin': 'https://claude.ai',
+ 'DNT': '1',
+ 'Connection': 'keep-alive',
+ 'Cookie': f'{self.claude_api_cookie}',
+ 'Sec-Fetch-Dest': 'empty',
+ 'Sec-Fetch-Mode': 'cors',
+ 'Sec-Fetch-Site': 'same-origin',
+ 'TE': 'trailers'
+ }
+
+ res = requests.post(base_url + "/api/append_message", headers=headers, data=payload,impersonate="chrome110",proxies= self.proxies,timeout=400)
+ if res.status_code == 200 or "pemission" in res.text:
+ # execute success
+ decoded_data = res.content.decode("utf-8")
+ decoded_data = re.sub('\n+', '\n', decoded_data).strip()
+ data_strings = decoded_data.split('\n')
+ completions = []
+ for data_string in data_strings:
+ json_str = data_string[6:].strip()
+ data = json.loads(json_str)
+ if 'completion' in data:
+ completions.append(data['completion'])
+
+ reply_content = ''.join(completions)
+
+ if "rate limi" in reply_content:
+ logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
+ return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间")
+ logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
+ self.sessions.session_reply(reply_content, session_id, 100)
+ return Reply(ReplyType.TEXT, reply_content)
+ else:
+ flag = self.check_cookie()
+ if flag == None:
+ return Reply(ReplyType.ERROR, self.error)
+
+ response = res.json()
+ error = response.get("error")
+ logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
+ f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")
+
+ if res.status_code >= 500:
+ # server error, need retry
+ time.sleep(2)
+ logger.warn(f"[CLAUDE] do retry, times={retry_count}")
+ return self._chat(query, context, retry_count + 1)
+ return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
+
+ except Exception as e:
+ logger.exception(e)
+ # retry
+ time.sleep(2)
+ logger.warn(f"[CLAUDE] do retry, times={retry_count}")
+ return self._chat(query, context, retry_count + 1)
diff --git a/bot/claude/claude_ai_session.py b/bot/claude/claude_ai_session.py
new file mode 100644
index 0000000..ede9e51
--- /dev/null
+++ b/bot/claude/claude_ai_session.py
@@ -0,0 +1,9 @@
+from bot.session_manager import Session
+
+
+class ClaudeAiSession(Session):
+ def __init__(self, session_id, system_prompt=None, model="claude"):
+ super().__init__(session_id, system_prompt)
+ self.model = model
+ # claude逆向不支持role prompt
+ # self.reset()
diff --git a/bot/gemini/google_gemini_bot.py b/bot/gemini/google_gemini_bot.py
new file mode 100644
index 0000000..1a49d60
--- /dev/null
+++ b/bot/gemini/google_gemini_bot.py
@@ -0,0 +1,75 @@
+"""
+Google gemini bot
+
+@author zhayujie
+@Date 2023/12/15
+"""
+# encoding:utf-8
+
+from bot.bot import Bot
+import google.generativeai as genai
+from bot.session_manager import SessionManager
+from bridge.context import ContextType, Context
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
+
+
+# OpenAI对话模型API (可用)
+class GoogleGeminiBot(Bot):
+
+ def __init__(self):
+ super().__init__()
+ self.api_key = conf().get("gemini_api_key")
+ # 复用文心的token计算方式
+ self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
+
+ def reply(self, query, context: Context = None) -> Reply:
+ try:
+ if context.type != ContextType.TEXT:
+ logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
+ return Reply(ReplyType.TEXT, None)
+ logger.info(f"[Gemini] query={query}")
+ session_id = context["session_id"]
+ session = self.sessions.session_query(query, session_id)
+ gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
+ genai.configure(api_key=self.api_key)
+ model = genai.GenerativeModel('gemini-pro')
+ response = model.generate_content(gemini_messages)
+ reply_text = response.text
+ self.sessions.session_reply(reply_text, session_id)
+ logger.info(f"[Gemini] reply={reply_text}")
+ return Reply(ReplyType.TEXT, reply_text)
+ except Exception as e:
+ logger.error("[Gemini] fetch reply error, may contain unsafe content")
+ logger.error(e)
+
+ def _convert_to_gemini_messages(self, messages: list):
+ res = []
+ for msg in messages:
+ if msg.get("role") == "user":
+ role = "user"
+ elif msg.get("role") == "assistant":
+ role = "model"
+ else:
+ continue
+ res.append({
+ "role": role,
+ "parts": [{"text": msg.get("content")}]
+ })
+ return res
+
+ def _filter_messages(self, messages: list):
+ res = []
+ turn = "user"
+ for i in range(len(messages) - 1, -1, -1):
+ message = messages[i]
+ if message.get("role") != turn:
+ continue
+ res.insert(0, message)
+ if turn == "user":
+ turn = "assistant"
+ elif turn == "assistant":
+ turn = "user"
+ return res
diff --git a/bot/linkai/link_ai_bot.py b/bot/linkai/link_ai_bot.py
new file mode 100644
index 0000000..d37d82a
--- /dev/null
+++ b/bot/linkai/link_ai_bot.py
@@ -0,0 +1,467 @@
+# access LinkAI knowledge base platform
+# docs: https://link-ai.tech/platform/link-app/wechat
+
+import re
+import time
+import requests
+import config
+from bot.bot import Bot
+from bot.chatgpt.chat_gpt_session import ChatGPTSession
+from bot.session_manager import SessionManager
+from bridge.context import Context, ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf, pconf
+import threading
+from common import memory, utils
+import base64
+import os
+
+class LinkAIBot(Bot):
+ # authentication failed
+ AUTH_FAILED_CODE = 401
+ NO_QUOTA_CODE = 406
+
+ def __init__(self):
+ super().__init__()
+ self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
+ self.args = {}
+
+ def reply(self, query, context: Context = None) -> Reply:
+ if context.type == ContextType.TEXT:
+ return self._chat(query, context)
+ elif context.type == ContextType.IMAGE_CREATE:
+ if not conf().get("text_to_image"):
+ logger.warn("[LinkAI] text_to_image is not enabled, ignore the IMAGE_CREATE request")
+ return Reply(ReplyType.TEXT, "")
+ ok, res = self.create_img(query, 0)
+ if ok:
+ reply = Reply(ReplyType.IMAGE_URL, res)
+ else:
+ reply = Reply(ReplyType.ERROR, res)
+ return reply
+ else:
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+ return reply
+
+ def _chat(self, query, context, retry_count=0) -> Reply:
+ """
+ 发起对话请求
+ :param query: 请求提示词
+ :param context: 对话上下文
+ :param retry_count: 当前递归重试次数
+ :return: 回复
+ """
+ if retry_count > 2:
+ # exit from retry 2 times
+ logger.warn("[LINKAI] failed after maximum number of retry times")
+ return Reply(ReplyType.TEXT, "请再问我一次吧")
+
+ try:
+ # load config
+ if context.get("generate_breaked_by"):
+ logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
+ app_code = None
+ else:
+ plugin_app_code = self._find_group_mapping_code(context)
+ app_code = context.kwargs.get("app_code") or plugin_app_code or conf().get("linkai_app_code")
+ linkai_api_key = conf().get("linkai_api_key")
+
+ session_id = context["session_id"]
+ session_message = self.sessions.session_msg_query(query, session_id)
+ logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")
+
+ # image process
+ img_cache = memory.USER_IMAGE_CACHE.get(session_id)
+ if img_cache:
+ messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
+ if messages:
+ session_message = messages
+
+ model = conf().get("model")
+ # remove system message
+ if session_message[0].get("role") == "system":
+ if app_code or model == "wenxin":
+ session_message.pop(0)
+ body = {
+ "app_code": app_code,
+ "messages": session_message,
+ "model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
+ "temperature": conf().get("temperature"),
+ "top_p": conf().get("top_p", 1),
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "session_id": session_id,
+ "channel_type": conf().get("channel_type")
+ }
+ try:
+ from linkai import LinkAIClient
+ client_id = LinkAIClient.fetch_client_id()
+ if client_id:
+ body["client_id"] = client_id
+ # start: client info deliver
+ if context.kwargs.get("msg"):
+ body["session_id"] = context.kwargs.get("msg").from_user_id
+ if context.kwargs.get("msg").is_group:
+ body["is_group"] = True
+ body["group_name"] = context.kwargs.get("msg").from_user_nickname
+ body["sender_name"] = context.kwargs.get("msg").actual_user_nickname
+ else:
+ if body.get("channel_type") in ["wechatcom_app"]:
+ body["sender_name"] = context.kwargs.get("msg").from_user_id
+ else:
+ body["sender_name"] = context.kwargs.get("msg").from_user_nickname
+
+ except Exception as e:
+ pass
+ file_id = context.kwargs.get("file_id")
+ if file_id:
+ body["file_id"] = file_id
+ logger.info(f"[LINKAI] query={query}, app_code={app_code}, model={body.get('model')}, file_id={file_id}")
+ headers = {"Authorization": "Bearer " + linkai_api_key}
+
+ # do http request
+ base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
+ res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
+ timeout=conf().get("request_timeout", 180))
+ if res.status_code == 200:
+ # execute success
+ response = res.json()
+ reply_content = response["choices"][0]["message"]["content"]
+ total_tokens = response["usage"]["total_tokens"]
+ logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
+ self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
+
+ agent_suffix = self._fetch_agent_suffix(response)
+ if agent_suffix:
+ reply_content += agent_suffix
+ if not agent_suffix:
+ knowledge_suffix = self._fetch_knowledge_search_suffix(response)
+ if knowledge_suffix:
+ reply_content += knowledge_suffix
+ # image process
+ if response["choices"][0].get("img_urls"):
+ thread = threading.Thread(target=self._send_image, args=(context.get("channel"), context, response["choices"][0].get("img_urls")))
+ thread.start()
+ if response["choices"][0].get("text_content"):
+ reply_content = response["choices"][0].get("text_content")
+ reply_content = self._process_url(reply_content)
+ return Reply(ReplyType.TEXT, reply_content)
+
+ else:
+ response = res.json()
+ error = response.get("error")
+ logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
+ f"msg={error.get('message')}, type={error.get('type')}")
+
+ if res.status_code >= 500:
+ # server error, need retry
+ time.sleep(2)
+ logger.warn(f"[LINKAI] do retry, times={retry_count}")
+ return self._chat(query, context, retry_count + 1)
+
+ return Reply(ReplyType.TEXT, "提问太快啦,请休息一下再问我吧")
+
+ except Exception as e:
+ logger.exception(e)
+ # retry
+ time.sleep(2)
+ logger.warn(f"[LINKAI] do retry, times={retry_count}")
+ return self._chat(query, context, retry_count + 1)
+
+ def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
+ try:
+ enable_image_input = False
+ app_info = self._fetch_app_info(app_code)
+ if not app_info:
+ logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
+ return None
+ plugins = app_info.get("data").get("plugins")
+ for plugin in plugins:
+ if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
+ enable_image_input = True
+ if not enable_image_input:
+ return
+ msg = img_cache.get("msg")
+ path = img_cache.get("path")
+ msg.prepare()
+ logger.info(f"[LinkAI] query with images, path={path}")
+ messages = self._build_vision_msg(query, path)
+ memory.USER_IMAGE_CACHE[session_id] = None
+ return messages
+ except Exception as e:
+ logger.exception(e)
+
+ def _find_group_mapping_code(self, context):
+ try:
+ if context.kwargs.get("isgroup"):
+ group_name = context.kwargs.get("msg").from_user_nickname
+ if config.plugin_config and config.plugin_config.get("linkai"):
+ linkai_config = config.plugin_config.get("linkai")
+ group_mapping = linkai_config.get("group_app_map")
+ if group_mapping and group_name:
+ return group_mapping.get(group_name)
+ except Exception as e:
+ logger.exception(e)
+ return None
+
+ def _build_vision_msg(self, query: str, path: str):
+ try:
+ suffix = utils.get_path_suffix(path)
+ with open(path, "rb") as file:
+ base64_str = base64.b64encode(file.read()).decode('utf-8')
+ messages = [{
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": query
+ },
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/{suffix};base64,{base64_str}"
+ }
+ }
+ ]
+ }]
+ return messages
+ except Exception as e:
+ logger.exception(e)
+
+ def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
+ if retry_count >= 2:
+ # exit from retry 2 times
+ logger.warn("[LINKAI] failed after maximum number of retry times")
+ return {
+ "total_tokens": 0,
+ "completion_tokens": 0,
+ "content": "请再问我一次吧"
+ }
+
+ try:
+ body = {
+ "app_code": app_code,
+ "messages": session.messages,
+ "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
+ "temperature": conf().get("temperature"),
+ "top_p": conf().get("top_p", 1),
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ }
+ if self.args.get("max_tokens"):
+ body["max_tokens"] = self.args.get("max_tokens")
+ headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
+
+ # do http request
+ base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
+ res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
+ timeout=conf().get("request_timeout", 180))
+ if res.status_code == 200:
+ # execute success
+ response = res.json()
+ reply_content = response["choices"][0]["message"]["content"]
+ total_tokens = response["usage"]["total_tokens"]
+ logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
+ return {
+ "total_tokens": total_tokens,
+ "completion_tokens": response["usage"]["completion_tokens"],
+ "content": reply_content,
+ }
+
+ else:
+ response = res.json()
+ error = response.get("error")
+ logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
+ f"msg={error.get('message')}, type={error.get('type')}")
+
+ if res.status_code >= 500:
+ # server error, need retry
+ time.sleep(2)
+ logger.warn(f"[LINKAI] do retry, times={retry_count}")
+ return self.reply_text(session, app_code, retry_count + 1)
+
+ return {
+ "total_tokens": 0,
+ "completion_tokens": 0,
+ "content": "提问太快啦,请休息一下再问我吧"
+ }
+
+ except Exception as e:
+ logger.exception(e)
+ # retry
+ time.sleep(2)
+ logger.warn(f"[LINKAI] do retry, times={retry_count}")
+ return self.reply_text(session, app_code, retry_count + 1)
+
+ def _fetch_app_info(self, app_code: str):
+ headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
+ # do http request
+ base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
+ params = {"app_code": app_code}
+ res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
+ if res.status_code == 200:
+ return res.json()
+ else:
+ logger.warning(f"[LinkAI] find app info exception, res={res}")
+
+ def create_img(self, query, retry_count=0, api_key=None):
+ try:
+ logger.info("[LinkImage] image_query={}".format(query))
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {conf().get('linkai_api_key')}"
+ }
+ data = {
+ "prompt": query,
+ "n": 1,
+ "model": conf().get("text_to_image") or "dall-e-2",
+ "response_format": "url",
+ "img_proxy": conf().get("image_proxy")
+ }
+ url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/images/generations"
+ res = requests.post(url, headers=headers, json=data, timeout=(5, 90))
+ t2 = time.time()
+ image_url = res.json()["data"][0]["url"]
+ logger.info("[OPEN_AI] image_url={}".format(image_url))
+ return True, image_url
+
+ except Exception as e:
+ logger.error(format(e))
+ return False, "画图出现问题,请休息一下再问我吧"
+
+
+ def _fetch_knowledge_search_suffix(self, response) -> str:
+ try:
+ if response.get("knowledge_base"):
+ search_hit = response.get("knowledge_base").get("search_hit")
+ first_similarity = response.get("knowledge_base").get("first_similarity")
+ logger.info(f"[LINKAI] knowledge base, search_hit={search_hit}, first_similarity={first_similarity}")
+ plugin_config = pconf("linkai")
+ if plugin_config and plugin_config.get("knowledge_base") and plugin_config.get("knowledge_base").get("search_miss_text_enabled"):
+ search_miss_similarity = plugin_config.get("knowledge_base").get("search_miss_similarity")
+ search_miss_text = plugin_config.get("knowledge_base").get("search_miss_suffix")
+ if not search_hit:
+ return search_miss_text
+ if search_miss_similarity and float(search_miss_similarity) > first_similarity:
+ return search_miss_text
+ except Exception as e:
+ logger.exception(e)
+
+
+ def _fetch_agent_suffix(self, response):
+ try:
+ plugin_list = []
+ logger.debug(f"[LinkAgent] res={response}")
+ if response.get("agent") and response.get("agent").get("chain") and response.get("agent").get("need_show_plugin"):
+ chain = response.get("agent").get("chain")
+ suffix = "\n\n- - - - - - - - - - - -"
+ i = 0
+ for turn in chain:
+ plugin_name = turn.get('plugin_name')
+ suffix += "\n"
+ need_show_thought = response.get("agent").get("need_show_thought")
+ if turn.get("thought") and plugin_name and need_show_thought:
+ suffix += f"{turn.get('thought')}\n"
+ if plugin_name:
+ plugin_list.append(turn.get('plugin_name'))
+ if turn.get('plugin_icon'):
+ suffix += f"{turn.get('plugin_icon')} "
+ suffix += f"{turn.get('plugin_name')}"
+ if turn.get('plugin_input'):
+ suffix += f":{turn.get('plugin_input')}"
+ if i < len(chain) - 1:
+ suffix += "\n"
+ i += 1
+ logger.info(f"[LinkAgent] use plugins: {plugin_list}")
+ return suffix
+ except Exception as e:
+ logger.exception(e)
+
+ def _process_url(self, text):
+ try:
+ url_pattern = re.compile(r'\[(.*?)\]\((http[s]?://.*?)\)')
+ def replace_markdown_url(match):
+ return f"{match.group(2)}"
+ return url_pattern.sub(replace_markdown_url, text)
+ except Exception as e:
+ logger.error(e)
+
+ def _send_image(self, channel, context, image_urls):
+ if not image_urls:
+ return
+ max_send_num = conf().get("max_media_send_count")
+ send_interval = conf().get("media_send_interval")
+ try:
+ i = 0
+ for url in image_urls:
+ if max_send_num and i >= max_send_num:
+ continue
+ i += 1
+ if url.endswith(".mp4"):
+ reply_type = ReplyType.VIDEO_URL
+ elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"):
+ reply_type = ReplyType.FILE
+ url = _download_file(url)
+ if not url:
+ continue
+ else:
+ reply_type = ReplyType.IMAGE_URL
+ reply = Reply(reply_type, url)
+ channel.send(reply, context)
+ if send_interval:
+ time.sleep(send_interval)
+ except Exception as e:
+ logger.error(e)
+
+
+def _download_file(url: str):
+ try:
+ file_path = "tmp"
+ if not os.path.exists(file_path):
+ os.makedirs(file_path)
+ file_name = url.split("/")[-1] # 获取文件名
+ file_path = os.path.join(file_path, file_name)
+ response = requests.get(url)
+ with open(file_path, "wb") as f:
+ f.write(response.content)
+ return file_path
+ except Exception as e:
+ logger.warn(e)
+
+
+class LinkAISessionManager(SessionManager):
+ def session_msg_query(self, query, session_id):
+ session = self.build_session(session_id)
+ messages = session.messages + [{"role": "user", "content": query}]
+ return messages
+
+ def session_reply(self, reply, session_id, total_tokens=None, query=None):
+ session = self.build_session(session_id)
+ if query:
+ session.add_query(query)
+ session.add_reply(reply)
+ try:
+ max_tokens = conf().get("conversation_max_tokens", 2500)
+ tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
+ logger.debug(f"[LinkAI] chat history, before tokens={total_tokens}, now tokens={tokens_cnt}")
+ except Exception as e:
+ logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
+ return session
+
+
+class LinkAISession(ChatGPTSession):
+ def calc_tokens(self):
+ if not self.messages:
+ return 0
+ return len(str(self.messages))
+
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
+ cur_tokens = self.calc_tokens()
+ if cur_tokens > max_tokens:
+ for i in range(0, len(self.messages)):
+ if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
+ self.messages.pop(i)
+ self.messages.pop(i - 1)
+ return self.calc_tokens()
+ return cur_tokens
diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py
new file mode 100644
index 0000000..1605625
--- /dev/null
+++ b/bot/openai/open_ai_bot.py
@@ -0,0 +1,122 @@
+# encoding:utf-8
+
+import time
+
+import openai
+import openai.error
+
+from bot.bot import Bot
+from bot.openai.open_ai_image import OpenAIImage
+from bot.openai.open_ai_session import OpenAISession
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+
+user_session = dict()
+
+
+# OpenAI对话模型API (可用)
+class OpenAIBot(Bot, OpenAIImage):
+ def __init__(self):
+ super().__init__()
+ openai.api_key = conf().get("open_ai_api_key")
+ if conf().get("open_ai_api_base"):
+ openai.api_base = conf().get("open_ai_api_base")
+ proxy = conf().get("proxy")
+ if proxy:
+ openai.proxy = proxy
+
+ self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
+ self.args = {
+ "model": conf().get("model") or "text-davinci-003", # 对话模型的名称
+ "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
+ "max_tokens": 1200, # 回复最大的字符数
+ "top_p": 1,
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+ "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
+ "stop": ["\n\n\n"],
+ }
+
+ def reply(self, query, context=None):
+ # acquire reply content
+ if context and context.type:
+ if context.type == ContextType.TEXT:
+ logger.info("[OPEN_AI] query={}".format(query))
+ session_id = context["session_id"]
+ reply = None
+ if query == "#清除记忆":
+ self.sessions.clear_session(session_id)
+ reply = Reply(ReplyType.INFO, "记忆已清除")
+ elif query == "#清除所有":
+ self.sessions.clear_all_session()
+ reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+ else:
+ session = self.sessions.session_query(query, session_id)
+ result = self.reply_text(session)
+ total_tokens, completion_tokens, reply_content = (
+ result["total_tokens"],
+ result["completion_tokens"],
+ result["content"],
+ )
+ logger.debug(
+ "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
+ )
+
+ if total_tokens == 0:
+ reply = Reply(ReplyType.ERROR, reply_content)
+ else:
+ self.sessions.session_reply(reply_content, session_id, total_tokens)
+ reply = Reply(ReplyType.TEXT, reply_content)
+ return reply
+ elif context.type == ContextType.IMAGE_CREATE:
+ ok, retstring = self.create_img(query, 0)
+ reply = None
+ if ok:
+ reply = Reply(ReplyType.IMAGE_URL, retstring)
+ else:
+ reply = Reply(ReplyType.ERROR, retstring)
+ return reply
+
+ def reply_text(self, session: OpenAISession, retry_count=0):
+ try:
+ response = openai.Completion.create(prompt=str(session), **self.args)
+ res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
+ total_tokens = response["usage"]["total_tokens"]
+ completion_tokens = response["usage"]["completion_tokens"]
+ logger.info("[OPEN_AI] reply={}".format(res_content))
+ return {
+ "total_tokens": total_tokens,
+ "completion_tokens": completion_tokens,
+ "content": res_content,
+ }
+ except Exception as e:
+ need_retry = retry_count < 2
+ result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+ if isinstance(e, openai.error.RateLimitError):
+ logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
+ result["content"] = "提问太快啦,请休息一下再问我吧"
+ if need_retry:
+ time.sleep(20)
+ elif isinstance(e, openai.error.Timeout):
+ logger.warn("[OPEN_AI] Timeout: {}".format(e))
+ result["content"] = "我没有收到你的消息"
+ if need_retry:
+ time.sleep(5)
+ elif isinstance(e, openai.error.APIConnectionError):
+ logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
+ need_retry = False
+ result["content"] = "我连接不到你的网络"
+ else:
+ logger.warn("[OPEN_AI] Exception: {}".format(e))
+ need_retry = False
+ self.sessions.clear_session(session.session_id)
+
+ if need_retry:
+ logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
+ return self.reply_text(session, retry_count + 1)
+ else:
+ return result
diff --git a/bot/openai/open_ai_image.py b/bot/openai/open_ai_image.py
new file mode 100644
index 0000000..3ff56c1
--- /dev/null
+++ b/bot/openai/open_ai_image.py
@@ -0,0 +1,43 @@
+import time
+
+import openai
+import openai.error
+
+from common.log import logger
+from common.token_bucket import TokenBucket
+from config import conf
+
+
+# OPENAI提供的画图接口
+class OpenAIImage(object):
+ def __init__(self):
+ openai.api_key = conf().get("open_ai_api_key")
+ if conf().get("rate_limit_dalle"):
+ self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
+
+ def create_img(self, query, retry_count=0, api_key=None, api_base=None):
+ try:
+ if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
+ return False, "请求太快了,请休息一下再问我吧"
+ logger.info("[OPEN_AI] image_query={}".format(query))
+ response = openai.Image.create(
+ api_key=api_key,
+ prompt=query, # 图片描述
+ n=1, # 每次生成图片的数量
+ model=conf().get("text_to_image") or "dall-e-2",
+ # size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
+ )
+ image_url = response["data"][0]["url"]
+ logger.info("[OPEN_AI] image_url={}".format(image_url))
+ return True, image_url
+ except openai.error.RateLimitError as e:
+ logger.warn(e)
+ if retry_count < 1:
+ time.sleep(5)
+ logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
+ return self.create_img(query, retry_count + 1)
+ else:
+ return False, "画图出现问题,请休息一下再问我吧"
+ except Exception as e:
+ logger.exception(e)
+ return False, "画图出现问题,请休息一下再问我吧"
diff --git a/bot/openai/open_ai_session.py b/bot/openai/open_ai_session.py
new file mode 100644
index 0000000..8f6aa4f
--- /dev/null
+++ b/bot/openai/open_ai_session.py
@@ -0,0 +1,73 @@
+from bot.session_manager import Session
+from common.log import logger
+
+
+class OpenAISession(Session):
+ def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
+ super().__init__(session_id, system_prompt)
+ self.model = model
+ self.reset()
+
+ def __str__(self):
+ # 构造对话模型的输入
+ """
+ e.g. Q: xxx
+ A: xxx
+ Q: xxx
+ """
+ prompt = ""
+ for item in self.messages:
+ if item["role"] == "system":
+ prompt += item["content"] + "<|endoftext|>\n\n\n"
+ elif item["role"] == "user":
+ prompt += "Q: " + item["content"] + "\n"
+ elif item["role"] == "assistant":
+ prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
+
+ if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
+ prompt += "A: "
+ return prompt
+
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
+ precise = True
+ try:
+ cur_tokens = self.calc_tokens()
+ except Exception as e:
+ precise = False
+ if cur_tokens is None:
+ raise e
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+ while cur_tokens > max_tokens:
+ if len(self.messages) > 1:
+ self.messages.pop(0)
+ elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
+ self.messages.pop(0)
+ if precise:
+ cur_tokens = self.calc_tokens()
+ else:
+ cur_tokens = len(str(self))
+ break
+ elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
+ logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
+ break
+ else:
+ logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
+ break
+ if precise:
+ cur_tokens = self.calc_tokens()
+ else:
+ cur_tokens = len(str(self))
+ return cur_tokens
+
+ def calc_tokens(self):
+ return num_tokens_from_string(str(self), self.model)
+
+
+# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+def num_tokens_from_string(string: str, model: str) -> int:
+ """Returns the number of tokens in a text string."""
+ import tiktoken
+
+ encoding = tiktoken.encoding_for_model(model)
+ num_tokens = len(encoding.encode(string, disallowed_special=()))
+ return num_tokens
diff --git a/bot/session_manager.py b/bot/session_manager.py
new file mode 100644
index 0000000..a6e89f9
--- /dev/null
+++ b/bot/session_manager.py
@@ -0,0 +1,91 @@
+from common.expired_dict import ExpiredDict
+from common.log import logger
+from config import conf
+
+
+class Session(object):
+ def __init__(self, session_id, system_prompt=None):
+ self.session_id = session_id
+ self.messages = []
+ if system_prompt is None:
+ self.system_prompt = conf().get("character_desc", "")
+ else:
+ self.system_prompt = system_prompt
+
+ # 重置会话
+ def reset(self):
+ system_item = {"role": "system", "content": self.system_prompt}
+ self.messages = [system_item]
+
+ def set_system_prompt(self, system_prompt):
+ self.system_prompt = system_prompt
+ self.reset()
+
+ def add_query(self, query):
+ user_item = {"role": "user", "content": query}
+ self.messages.append(user_item)
+
+ def add_reply(self, reply):
+ assistant_item = {"role": "assistant", "content": reply}
+ self.messages.append(assistant_item)
+
+ def discard_exceeding(self, max_tokens=None, cur_tokens=None):
+ raise NotImplementedError
+
+ def calc_tokens(self):
+ raise NotImplementedError
+
+
+class SessionManager(object):
+ def __init__(self, sessioncls, **session_args):
+ if conf().get("expires_in_seconds"):
+ sessions = ExpiredDict(conf().get("expires_in_seconds"))
+ else:
+ sessions = dict()
+ self.sessions = sessions
+ self.sessioncls = sessioncls
+ self.session_args = session_args
+
+ def build_session(self, session_id, system_prompt=None):
+ """
+ 如果session_id不在sessions中,创建一个新的session并添加到sessions中
+ 如果system_prompt不会空,会更新session的system_prompt并重置session
+ """
+ if session_id is None:
+ return self.sessioncls(session_id, system_prompt, **self.session_args)
+
+ if session_id not in self.sessions:
+ self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
+ elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
+ self.sessions[session_id].set_system_prompt(system_prompt)
+ session = self.sessions[session_id]
+ return session
+
+ def session_query(self, query, session_id):
+ session = self.build_session(session_id)
+ session.add_query(query)
+ try:
+ max_tokens = conf().get("conversation_max_tokens", 1000)
+ total_tokens = session.discard_exceeding(max_tokens, None)
+ logger.debug("prompt tokens used={}".format(total_tokens))
+ except Exception as e:
+ logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
+ return session
+
+ def session_reply(self, reply, session_id, total_tokens=None):
+ session = self.build_session(session_id)
+ session.add_reply(reply)
+ try:
+ max_tokens = conf().get("conversation_max_tokens", 1000)
+ tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
+ logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
+ except Exception as e:
+ logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
+ return session
+
+ def clear_session(self, session_id):
+ if session_id in self.sessions:
+ del self.sessions[session_id]
+
+ def clear_all_session(self):
+ self.sessions.clear()
diff --git a/bot/xunfei/xunfei_spark_bot.py b/bot/xunfei/xunfei_spark_bot.py
new file mode 100644
index 0000000..395d81e
--- /dev/null
+++ b/bot/xunfei/xunfei_spark_bot.py
@@ -0,0 +1,267 @@
+# encoding:utf-8
+
+import requests, json
+from bot.bot import Bot
+from bot.session_manager import SessionManager
+from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
+from bridge.context import ContextType, Context
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+from common import const
+import time
+import _thread as thread
+import datetime
+from datetime import datetime
+from wsgiref.handlers import format_date_time
+from urllib.parse import urlencode
+import base64
+import ssl
+import hashlib
+import hmac
+import json
+from time import mktime
+from urllib.parse import urlparse
+import websocket
+import queue
+import threading
+import random
+
+# 消息队列 map
+queue_map = dict()
+
+# 响应队列 map
+reply_map = dict()
+
+
+class XunFeiBot(Bot):
+ def __init__(self):
+ super().__init__()
+ self.app_id = conf().get("xunfei_app_id")
+ self.api_key = conf().get("xunfei_api_key")
+ self.api_secret = conf().get("xunfei_api_secret")
+ # 默认使用v2.0版本: "generalv2"
+ # v1.5版本为 "general"
+ # v3.0版本为: "generalv3"
+ self.domain = "generalv3"
+ # 默认使用v2.0版本: "ws://spark-api.xf-yun.com/v2.1/chat"
+ # v1.5版本为: "ws://spark-api.xf-yun.com/v1.1/chat"
+ # v3.0版本为: "ws://spark-api.xf-yun.com/v3.1/chat"
+ self.spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"
+ self.host = urlparse(self.spark_url).netloc
+ self.path = urlparse(self.spark_url).path
+ # 和wenxin使用相同的session机制
+ self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
+
+ def reply(self, query, context: Context = None) -> Reply:
+ if context.type == ContextType.TEXT:
+ logger.info("[XunFei] query={}".format(query))
+ session_id = context["session_id"]
+ request_id = self.gen_request_id(session_id)
+ reply_map[request_id] = ""
+ session = self.sessions.session_query(query, session_id)
+ threading.Thread(target=self.create_web_socket,
+ args=(session.messages, request_id)).start()
+ depth = 0
+ time.sleep(0.1)
+ t1 = time.time()
+ usage = {}
+ while depth <= 300:
+ try:
+ data_queue = queue_map.get(request_id)
+ if not data_queue:
+ depth += 1
+ time.sleep(0.1)
+ continue
+ data_item = data_queue.get(block=True, timeout=0.1)
+ if data_item.is_end:
+ # 请求结束
+ del queue_map[request_id]
+ if data_item.reply:
+ reply_map[request_id] += data_item.reply
+ usage = data_item.usage
+ break
+
+ reply_map[request_id] += data_item.reply
+ depth += 1
+ except Exception as e:
+ depth += 1
+ continue
+ t2 = time.time()
+ logger.info(
+ f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}"
+ )
+ self.sessions.session_reply(reply_map[request_id], session_id,
+ usage.get("total_tokens"))
+ reply = Reply(ReplyType.TEXT, reply_map[request_id])
+ del reply_map[request_id]
+ return reply
+ else:
+ reply = Reply(ReplyType.ERROR,
+ "Bot不支持处理{}类型的消息".format(context.type))
+ return reply
+
+ def create_web_socket(self, prompt, session_id, temperature=0.5):
+ logger.info(f"[XunFei] start connect, prompt={prompt}")
+ websocket.enableTrace(False)
+ wsUrl = self.create_url()
+ ws = websocket.WebSocketApp(wsUrl,
+ on_message=on_message,
+ on_error=on_error,
+ on_close=on_close,
+ on_open=on_open)
+ data_queue = queue.Queue(1000)
+ queue_map[session_id] = data_queue
+ ws.appid = self.app_id
+ ws.question = prompt
+ ws.domain = self.domain
+ ws.session_id = session_id
+ ws.temperature = temperature
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
+
+ def gen_request_id(self, session_id: str):
+ return session_id + "_" + str(int(time.time())) + "" + str(
+ random.randint(0, 100))
+
+ # 生成url
+ def create_url(self):
+ # 生成RFC1123格式的时间戳
+ now = datetime.now()
+ date = format_date_time(mktime(now.timetuple()))
+
+ # 拼接字符串
+ signature_origin = "host: " + self.host + "\n"
+ signature_origin += "date: " + date + "\n"
+ signature_origin += "GET " + self.path + " HTTP/1.1"
+
+ # 进行hmac-sha256进行加密
+ signature_sha = hmac.new(self.api_secret.encode('utf-8'),
+ signature_origin.encode('utf-8'),
+ digestmod=hashlib.sha256).digest()
+
+ signature_sha_base64 = base64.b64encode(signature_sha).decode(
+ encoding='utf-8')
+
+ authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
+ f'signature="{signature_sha_base64}"'
+
+ authorization = base64.b64encode(
+ authorization_origin.encode('utf-8')).decode(encoding='utf-8')
+
+ # 将请求的鉴权参数组合为字典
+ v = {"authorization": authorization, "date": date, "host": self.host}
+ # 拼接鉴权参数,生成url
+ url = self.spark_url + '?' + urlencode(v)
+ # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
+ return url
+
+ def gen_params(self, appid, domain, question):
+ """
+ 通过appid和用户的提问来生成请参数
+ """
+ data = {
+ "header": {
+ "app_id": appid,
+ "uid": "1234"
+ },
+ "parameter": {
+ "chat": {
+ "domain": domain,
+ "random_threshold": 0.5,
+ "max_tokens": 2048,
+ "auditing": "default"
+ }
+ },
+ "payload": {
+ "message": {
+ "text": question
+ }
+ }
+ }
+ return data
+
+
+class ReplyItem:
+ def __init__(self, reply, usage=None, is_end=False):
+ self.is_end = is_end
+ self.reply = reply
+ self.usage = usage
+
+
+# 收到websocket错误的处理
+def on_error(ws, error):
+ logger.error(f"[XunFei] error: {str(error)}")
+
+
+# 收到websocket关闭的处理
+def on_close(ws, one, two):
+ data_queue = queue_map.get(ws.session_id)
+ data_queue.put("END")
+
+
+# 收到websocket连接建立的处理
+def on_open(ws):
+ logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
+ thread.start_new_thread(run, (ws, ))
+
+
+def run(ws, *args):
+ data = json.dumps(
+ gen_params(appid=ws.appid,
+ domain=ws.domain,
+ question=ws.question,
+ temperature=ws.temperature))
+ ws.send(data)
+
+
+# Websocket 操作
+# 收到websocket消息的处理
+def on_message(ws, message):
+ data = json.loads(message)
+ code = data['header']['code']
+ if code != 0:
+ logger.error(f'请求错误: {code}, {data}')
+ ws.close()
+ else:
+ choices = data["payload"]["choices"]
+ status = choices["status"]
+ content = choices["text"][0]["content"]
+ data_queue = queue_map.get(ws.session_id)
+ if not data_queue:
+ logger.error(
+ f"[XunFei] can't find data queue, session_id={ws.session_id}")
+ return
+ reply_item = ReplyItem(content)
+ if status == 2:
+ usage = data["payload"].get("usage")
+ reply_item = ReplyItem(content, usage)
+ reply_item.is_end = True
+ ws.close()
+ data_queue.put(reply_item)
+
+
+def gen_params(appid, domain, question, temperature=0.5):
+ """
+ 通过appid和用户的提问来生成请参数
+ """
+ data = {
+ "header": {
+ "app_id": appid,
+ "uid": "1234"
+ },
+ "parameter": {
+ "chat": {
+ "domain": domain,
+ "temperature": temperature,
+ "random_threshold": 0.5,
+ "max_tokens": 2048,
+ "auditing": "default"
+ }
+ },
+ "payload": {
+ "message": {
+ "text": question
+ }
+ }
+ }
+ return data
diff --git a/bot/zhipuai/zhipu_ai_image.py b/bot/zhipuai/zhipu_ai_image.py
new file mode 100644
index 0000000..84eb567
--- /dev/null
+++ b/bot/zhipuai/zhipu_ai_image.py
@@ -0,0 +1,29 @@
+from common.log import logger
+from config import conf
+
+
+# ZhipuAI提供的画图接口
+
+class ZhipuAIImage(object):
+ def __init__(self):
+ from zhipuai import ZhipuAI
+ self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
+
+ def create_img(self, query, retry_count=0, api_key=None, api_base=None):
+ try:
+ if conf().get("rate_limit_dalle"):
+ return False, "请求太快了,请休息一下再问我吧"
+ logger.info("[ZHIPU_AI] image_query={}".format(query))
+ response = self.client.images.generations(
+ prompt=query,
+ n=1, # 每次生成图片的数量
+ model=conf().get("text_to_image") or "cogview-3",
+ size=conf().get("image_create_size", "1024x1024"), # 图片大小,可选有 256x256, 512x512, 1024x1024
+ quality="standard",
+ )
+ image_url = response.data[0].url
+ logger.info("[ZHIPU_AI] image_url={}".format(image_url))
+ return True, image_url
+ except Exception as e:
+ logger.exception(e)
+ return False, "画图出现问题,请休息一下再问我吧"
diff --git a/bot/zhipuai/zhipu_ai_session.py b/bot/zhipuai/zhipu_ai_session.py
new file mode 100644
index 0000000..394d521
--- /dev/null
+++ b/bot/zhipuai/zhipu_ai_session.py
@@ -0,0 +1,51 @@
+from bot.session_manager import Session
+from common.log import logger
+
+
+class ZhipuAISession(Session):
+ def __init__(self, session_id, system_prompt=None, model="glm-4"):
+ super().__init__(session_id, system_prompt)
+ self.model = model
+ self.reset()
+
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
+ precise = True
+ try:
+ cur_tokens = self.calc_tokens()
+ except Exception as e:
+ precise = False
+ if cur_tokens is None:
+ raise e
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+ while cur_tokens > max_tokens:
+ if len(self.messages) > 2:
+ self.messages.pop(1)
+ elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
+ self.messages.pop(1)
+ if precise:
+ cur_tokens = self.calc_tokens()
+ else:
+ cur_tokens = cur_tokens - max_tokens
+ break
+ elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
+ logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+ break
+ else:
+ logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
+ len(self.messages)))
+ break
+ if precise:
+ cur_tokens = self.calc_tokens()
+ else:
+ cur_tokens = cur_tokens - max_tokens
+ return cur_tokens
+
+ def calc_tokens(self):
+ return num_tokens_from_messages(self.messages, self.model)
+
+
+def num_tokens_from_messages(messages, model):
+ tokens = 0
+ for msg in messages:
+ tokens += len(msg["content"])
+ return tokens
diff --git a/bot/zhipuai/zhipuai_bot.py b/bot/zhipuai/zhipuai_bot.py
new file mode 100644
index 0000000..d8eed4d
--- /dev/null
+++ b/bot/zhipuai/zhipuai_bot.py
@@ -0,0 +1,149 @@
+# encoding:utf-8
+
+import time
+
+import openai
+import openai.error
+from bot.bot import Bot
+from bot.zhipuai.zhipu_ai_session import ZhipuAISession
+from bot.zhipuai.zhipu_ai_image import ZhipuAIImage
+from bot.session_manager import SessionManager
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf, load_config
+from zhipuai import ZhipuAI
+
+
+# ZhipuAI对话模型API
+class ZHIPUAIBot(Bot, ZhipuAIImage):
+ def __init__(self):
+ super().__init__()
+ self.sessions = SessionManager(ZhipuAISession, model=conf().get("model") or "ZHIPU_AI")
+ self.args = {
+ "model": conf().get("model") or "glm-4", # 对话模型的名称
+ "temperature": conf().get("temperature", 0.9), # 值在(0,1)之间(智谱AI 的温度不能取 0 或者 1)
+ "top_p": conf().get("top_p", 0.7), # 值在(0,1)之间(智谱AI 的 top_p 不能取 0 或者 1)
+ }
+ self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
+
+ def reply(self, query, context=None):
+ # acquire reply content
+ if context.type == ContextType.TEXT:
+ logger.info("[ZHIPU_AI] query={}".format(query))
+
+ session_id = context["session_id"]
+ reply = None
+ clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
+ if query in clear_memory_commands:
+ self.sessions.clear_session(session_id)
+ reply = Reply(ReplyType.INFO, "记忆已清除")
+ elif query == "#清除所有":
+ self.sessions.clear_all_session()
+ reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+ elif query == "#更新配置":
+ load_config()
+ reply = Reply(ReplyType.INFO, "配置已更新")
+ if reply:
+ return reply
+ session = self.sessions.session_query(query, session_id)
+ logger.debug("[ZHIPU_AI] session query={}".format(session.messages))
+
+ api_key = context.get("openai_api_key") or openai.api_key
+ model = context.get("gpt_model")
+ new_args = None
+ if model:
+ new_args = self.args.copy()
+ new_args["model"] = model
+ # if context.get('stream'):
+ # # reply in stream
+ # return self.reply_text_stream(query, new_query, session_id)
+
+ reply_content = self.reply_text(session, api_key, args=new_args)
+ logger.debug(
+ "[ZHIPU_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+ session.messages,
+ session_id,
+ reply_content["content"],
+ reply_content["completion_tokens"],
+ )
+ )
+ if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
+ elif reply_content["completion_tokens"] > 0:
+ self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+ reply = Reply(ReplyType.TEXT, reply_content["content"])
+ else:
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
+ logger.debug("[ZHIPU_AI] reply {} used 0 tokens.".format(reply_content))
+ return reply
+ elif context.type == ContextType.IMAGE_CREATE:
+ ok, retstring = self.create_img(query, 0)
+ reply = None
+ if ok:
+ reply = Reply(ReplyType.IMAGE_URL, retstring)
+ else:
+ reply = Reply(ReplyType.ERROR, retstring)
+ return reply
+
+ else:
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
+ return reply
+
+ def reply_text(self, session: ZhipuAISession, api_key=None, args=None, retry_count=0) -> dict:
+ """
+ call openai's ChatCompletion to get the answer
+ :param session: a conversation session
+ :param session_id: session id
+ :param retry_count: retry count
+ :return: {}
+ """
+ try:
+ # if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
+ # raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
+ # if api_key == None, the default openai.api_key will be used
+ if args is None:
+ args = self.args
+ # response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
+ response = self.client.chat.completions.create(messages=session.messages, **args)
+ # logger.debug("[ZHIPU_AI] response={}".format(response))
+ # logger.info("[ZHIPU_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
+
+ return {
+ "total_tokens": response.usage.total_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "content": response.choices[0].message.content,
+ }
+ except Exception as e:
+ need_retry = retry_count < 2
+ result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
+ if isinstance(e, openai.error.RateLimitError):
+ logger.warn("[ZHIPU_AI] RateLimitError: {}".format(e))
+ result["content"] = "提问太快啦,请休息一下再问我吧"
+ if need_retry:
+ time.sleep(20)
+ elif isinstance(e, openai.error.Timeout):
+ logger.warn("[ZHIPU_AI] Timeout: {}".format(e))
+ result["content"] = "我没有收到你的消息"
+ if need_retry:
+ time.sleep(5)
+ elif isinstance(e, openai.error.APIError):
+ logger.warn("[ZHIPU_AI] Bad Gateway: {}".format(e))
+ result["content"] = "请再问我一次"
+ if need_retry:
+ time.sleep(10)
+ elif isinstance(e, openai.error.APIConnectionError):
+ logger.warn("[ZHIPU_AI] APIConnectionError: {}".format(e))
+ result["content"] = "我连接不到你的网络"
+ if need_retry:
+ time.sleep(5)
+ else:
+ logger.exception("[ZHIPU_AI] Exception: {}".format(e), e)
+ need_retry = False
+ self.sessions.clear_session(session.session_id)
+
+ if need_retry:
+ logger.warn("[ZHIPU_AI] 第{}次重试".format(retry_count + 1))
+ return self.reply_text(session, api_key, args, retry_count + 1)
+ else:
+ return result
diff --git a/bridge/bridge.py b/bridge/bridge.py
new file mode 100644
index 0000000..88e6b18
--- /dev/null
+++ b/bridge/bridge.py
@@ -0,0 +1,86 @@
+from bot.bot_factory import create_bot
+from bridge.context import Context
+from bridge.reply import Reply
+from common import const
+from common.log import logger
+from common.singleton import singleton
+from config import conf
+from translate.factory import create_translator
+from voice.factory import create_voice
+
+
+@singleton
+class Bridge(object):
+ def __init__(self):
+ self.btype = {
+ "chat": const.CHATGPT,
+ "voice_to_text": conf().get("voice_to_text", "openai"),
+ "text_to_voice": conf().get("text_to_voice", "google"),
+ "translate": conf().get("translate", "baidu"),
+ }
+ model_type = conf().get("model") or const.GPT35
+ if model_type in ["text-davinci-003"]:
+ self.btype["chat"] = const.OPEN_AI
+ if conf().get("use_azure_chatgpt", False):
+ self.btype["chat"] = const.CHATGPTONAZURE
+ if model_type in ["wenxin", "wenxin-4"]:
+ self.btype["chat"] = const.BAIDU
+ if model_type in ["xunfei"]:
+ self.btype["chat"] = const.XUNFEI
+ if model_type in [const.QWEN]:
+ self.btype["chat"] = const.QWEN
+ if model_type in [const.GEMINI]:
+ self.btype["chat"] = const.GEMINI
+ if model_type in [const.ZHIPU_AI]:
+ self.btype["chat"] = const.ZHIPU_AI
+
+ if conf().get("use_linkai") and conf().get("linkai_api_key"):
+ self.btype["chat"] = const.LINKAI
+ if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
+ self.btype["voice_to_text"] = const.LINKAI
+ if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
+ self.btype["text_to_voice"] = const.LINKAI
+
+ if model_type in ["claude"]:
+ self.btype["chat"] = const.CLAUDEAI
+ self.bots = {}
+ self.chat_bots = {}
+
+ def get_bot(self, typename):
+ if self.bots.get(typename) is None:
+ logger.info("create bot {} for {}".format(self.btype[typename], typename))
+ if typename == "text_to_voice":
+ self.bots[typename] = create_voice(self.btype[typename])
+ elif typename == "voice_to_text":
+ self.bots[typename] = create_voice(self.btype[typename])
+ elif typename == "chat":
+ self.bots[typename] = create_bot(self.btype[typename])
+ elif typename == "translate":
+ self.bots[typename] = create_translator(self.btype[typename])
+ return self.bots[typename]
+
+ def get_bot_type(self, typename):
+ return self.btype[typename]
+
+ def fetch_reply_content(self, query, context: Context) -> Reply:
+ return self.get_bot("chat").reply(query, context)
+
+ def fetch_voice_to_text(self, voiceFile) -> Reply:
+ return self.get_bot("voice_to_text").voiceToText(voiceFile)
+
+ def fetch_text_to_voice(self, text) -> Reply:
+ return self.get_bot("text_to_voice").textToVoice(text)
+
+ def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
+ return self.get_bot("translate").translate(text, from_lang, to_lang)
+
+ def find_chat_bot(self, bot_type: str):
+ if self.chat_bots.get(bot_type) is None:
+ self.chat_bots[bot_type] = create_bot(bot_type)
+ return self.chat_bots.get(bot_type)
+
+ def reset_bot(self):
+ """
+ 重置bot路由
+ """
+ self.__init__()
diff --git a/bridge/context.py b/bridge/context.py
new file mode 100644
index 0000000..04d6320
--- /dev/null
+++ b/bridge/context.py
@@ -0,0 +1,71 @@
+# encoding:utf-8
+
+from enum import Enum
+
+
+class ContextType(Enum):
+ TEXT = 1 # 文本消息
+ VOICE = 2 # 音频消息
+ IMAGE = 3 # 图片消息
+ FILE = 4 # 文件信息
+ VIDEO = 5 # 视频信息
+ SHARING = 6 # 分享信息
+
+ IMAGE_CREATE = 10 # 创建图片命令
+ ACCEPT_FRIEND = 19 # 同意好友请求
+ JOIN_GROUP = 20 # 加入群聊
+ PATPAT = 21 # 拍了拍
+ FUNCTION = 22 # 函数调用
+ EXIT_GROUP = 23 #退出
+
+
+ def __str__(self):
+ return self.name
+
+
+class Context:
+ def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
+ self.type = type
+ self.content = content
+ self.kwargs = kwargs
+
+ def __contains__(self, key):
+ if key == "type":
+ return self.type is not None
+ elif key == "content":
+ return self.content is not None
+ else:
+ return key in self.kwargs
+
+ def __getitem__(self, key):
+ if key == "type":
+ return self.type
+ elif key == "content":
+ return self.content
+ else:
+ return self.kwargs[key]
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def __setitem__(self, key, value):
+ if key == "type":
+ self.type = value
+ elif key == "content":
+ self.content = value
+ else:
+ self.kwargs[key] = value
+
+ def __delitem__(self, key):
+ if key == "type":
+ self.type = None
+ elif key == "content":
+ self.content = None
+ else:
+ del self.kwargs[key]
+
+ def __str__(self):
+ return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
diff --git a/bridge/reply.py b/bridge/reply.py
new file mode 100644
index 0000000..0031484
--- /dev/null
+++ b/bridge/reply.py
@@ -0,0 +1,31 @@
+# encoding:utf-8
+
+from enum import Enum
+
+
+class ReplyType(Enum):
+ TEXT = 1 # 文本
+ VOICE = 2 # 音频文件
+ IMAGE = 3 # 图片文件
+ IMAGE_URL = 4 # 图片URL
+ VIDEO_URL = 5 # 视频URL
+ FILE = 6 # 文件
+ CARD = 7 # 微信名片,仅支持ntchat
+ InviteRoom = 8 # 邀请好友进群
+ INFO = 9
+ ERROR = 10
+ TEXT_ = 11 # 强制文本
+ VIDEO = 12
+ MINIAPP = 13 # 小程序
+
+ def __str__(self):
+ return self.name
+
+
+class Reply:
+ def __init__(self, type: ReplyType = None, content=None):
+ self.type = type
+ self.content = content
+
+ def __str__(self):
+ return "Reply(type={}, content={})".format(self.type, self.content)
diff --git a/channel/channel.py b/channel/channel.py
new file mode 100644
index 0000000..c225342
--- /dev/null
+++ b/channel/channel.py
@@ -0,0 +1,44 @@
+"""
+Message sending channel abstract class
+"""
+
+from bridge.bridge import Bridge
+from bridge.context import Context
+from bridge.reply import *
+
+
+class Channel(object):
+ channel_type = ""
+ NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
+
+ def startup(self):
+ """
+ init channel
+ """
+ raise NotImplementedError
+
+ def handle_text(self, msg):
+ """
+ process received msg
+ :param msg: message object
+ """
+ raise NotImplementedError
+
+ # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
+ def send(self, reply: Reply, context: Context):
+ """
+ send message to user
+ :param msg: message content
+ :param receiver: receiver channel account
+ :return:
+ """
+ raise NotImplementedError
+
+ def build_reply_content(self, query, context: Context = None) -> Reply:
+ return Bridge().fetch_reply_content(query, context)
+
+ def build_voice_to_text(self, voice_file) -> Reply:
+ return Bridge().fetch_voice_to_text(voice_file)
+
+ def build_text_to_voice(self, text) -> Reply:
+ return Bridge().fetch_text_to_voice(text)
diff --git a/channel/channel_factory.py b/channel/channel_factory.py
new file mode 100644
index 0000000..c2c6937
--- /dev/null
+++ b/channel/channel_factory.py
@@ -0,0 +1,45 @@
+"""
+channel factory
+"""
+from common import const
+from .channel import Channel
+
+
+def create_channel(channel_type) -> Channel:
+ """
+ create a channel instance
+ :param channel_type: channel type code
+ :return: channel instance
+ """
+ ch = Channel()
+ if channel_type == "wx":
+ from channel.wechat.wechat_channel import WechatChannel
+ ch = WechatChannel()
+ elif channel_type == "wxy":
+ from channel.wechat.wechaty_channel import WechatyChannel
+ ch = WechatyChannel()
+ elif channel_type == "terminal":
+ from channel.terminal.terminal_channel import TerminalChannel
+ ch = TerminalChannel()
+ elif channel_type == "wechatmp":
+ from channel.wechatmp.wechatmp_channel import WechatMPChannel
+ ch = WechatMPChannel(passive_reply=True)
+ elif channel_type == "wechatmp_service":
+ from channel.wechatmp.wechatmp_channel import WechatMPChannel
+ ch = WechatMPChannel(passive_reply=False)
+ elif channel_type == "wechatcom_app":
+ from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
+ ch = WechatComAppChannel()
+ elif channel_type == "wework":
+ from channel.wework.wework_channel import WeworkChannel
+ ch = WeworkChannel()
+ elif channel_type == const.FEISHU:
+ from channel.feishu.feishu_channel import FeiShuChanel
+ ch = FeiShuChanel()
+ elif channel_type == const.DINGTALK:
+ from channel.dingtalk.dingtalk_channel import DingTalkChanel
+ ch = DingTalkChanel()
+ else:
+ raise RuntimeError
+ ch.channel_type = channel_type
+ return ch
diff --git a/channel/chat_channel.py b/channel/chat_channel.py
new file mode 100644
index 0000000..fe71207
--- /dev/null
+++ b/channel/chat_channel.py
@@ -0,0 +1,394 @@
+import os
+import re
+import threading
+import time
+from asyncio import CancelledError
+from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent import futures
+
+from bridge.context import *
+from bridge.reply import *
+from channel.channel import Channel
+from common.dequeue import Dequeue
+from common import memory
+from plugins import *
+
+try:
+ from voice.audio_convert import any_to_wav
+except Exception as e:
+ pass
+
+handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
+
+
+# 抽象类, 它包含了与消息通道无关的通用处理逻辑
+class ChatChannel(Channel):
+ name = None # 登录的用户名
+ user_id = None # 登录的用户id
+ futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
+ sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
+ lock = threading.Lock() # 用于控制对sessions的访问
+
+ def __init__(self):
+ _thread = threading.Thread(target=self.consume)
+ _thread.setDaemon(True)
+ _thread.start()
+
+ # 根据消息构造context,消息内容相关的触发项写在这里
+ def _compose_context(self, ctype: ContextType, content, **kwargs):
+ context = Context(ctype, content)
+ context.kwargs = kwargs
+ # context首次传入时,origin_ctype是None,
+ # 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
+ # origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
+ if "origin_ctype" not in context:
+ context["origin_ctype"] = ctype
+ # context首次传入时,receiver是None,根据类型设置receiver
+ first_in = "receiver" not in context
+ # 群名匹配过程,设置session_id和receiver
+ if first_in: # context首次传入时,receiver是None,根据类型设置receiver
+ config = conf()
+ cmsg = context["msg"]
+ user_data = conf().get_user_data(cmsg.from_user_id)
+ context["openai_api_key"] = user_data.get("openai_api_key")
+ context["gpt_model"] = user_data.get("gpt_model")
+ if context.get("isgroup", False):
+ group_name = cmsg.other_user_nickname
+ group_id = cmsg.other_user_id
+
+ group_name_white_list = config.get("group_name_white_list", [])
+ group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
+ if any(
+ [
+ group_name in group_name_white_list,
+ "ALL_GROUP" in group_name_white_list,
+ check_contain(group_name, group_name_keyword_white_list),
+ ]
+ ):
+ group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
+ session_id = cmsg.actual_user_id
+ if any(
+ [
+ group_name in group_chat_in_one_session,
+ "ALL_GROUP" in group_chat_in_one_session,
+ ]
+ ):
+ session_id = group_id
+ else:
+ return None
+ context["session_id"] = session_id
+ context["receiver"] = group_id
+ else:
+ context["session_id"] = cmsg.other_user_id
+ context["receiver"] = cmsg.other_user_id
+ e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
+ context = e_context["context"]
+ if e_context.is_pass() or context is None:
+ return context
+ if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
+ logger.debug("[WX]self message skipped")
+ return None
+
+ # 消息内容匹配过程,并处理content
+ if ctype == ContextType.TEXT:
+ if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
+ logger.debug(content)
+ logger.debug("[WX]reference query skipped")
+ return None
+
+ nick_name_black_list = conf().get("nick_name_black_list", [])
+ if context.get("isgroup", False): # 群聊
+ # 校验关键字
+ match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
+ match_contain = check_contain(content, conf().get("group_chat_keyword"))
+ flag = False
+ if context["msg"].to_user_id != context["msg"].actual_user_id:
+ if match_prefix is not None or match_contain is not None:
+ flag = True
+ if match_prefix:
+ content = content.replace(match_prefix, "", 1).strip()
+ if context["msg"].is_at:
+ nick_name = context["msg"].actual_user_nickname
+ if nick_name and nick_name in nick_name_black_list:
+ # 黑名单过滤
+ logger.warning(f"[WX] Nickname {nick_name} in In BlackList, ignore")
+ return None
+
+ logger.info("[WX]receive group at")
+ if not conf().get("group_at_off", False):
+ flag = True
+ pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
+ subtract_res = re.sub(pattern, r"", content)
+ if isinstance(context["msg"].at_list, list):
+ for at in context["msg"].at_list:
+ pattern = f"@{re.escape(at)}(\u2005|\u0020)"
+ subtract_res = re.sub(pattern, r"", subtract_res)
+ if subtract_res == content and context["msg"].self_display_name:
+ # 前缀移除后没有变化,使用群昵称再次移除
+ pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
+ subtract_res = re.sub(pattern, r"", content)
+ content = subtract_res
+ if not flag:
+ if context["origin_ctype"] == ContextType.VOICE:
+ logger.info("[WX]receive group voice, but checkprefix didn't match")
+ return None
+ else: # 单聊
+ nick_name = context["msg"].from_user_nickname
+ if nick_name and nick_name in nick_name_black_list:
+ # 黑名单过滤
+ logger.warning(f"[WX] Nickname '{nick_name}' in In BlackList, ignore")
+ return None
+
+ match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
+ if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
+ content = content.replace(match_prefix, "", 1).strip()
+ elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
+ pass
+ else:
+ return None
+ content = content.strip()
+ img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
+ if img_match_prefix:
+ content = content.replace(img_match_prefix, "", 1)
+ context.type = ContextType.IMAGE_CREATE
+ else:
+ context.type = ContextType.TEXT
+ context.content = content.strip()
+ if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
+ context["desire_rtype"] = ReplyType.VOICE
+ elif context.type == ContextType.VOICE:
+ if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
+ context["desire_rtype"] = ReplyType.VOICE
+
+ return context
+
+ def _handle(self, context: Context):
+ if context is None or not context.content:
+ return
+ logger.debug("[WX] ready to handle context: {}".format(context))
+ # reply的构建步骤
+ reply = self._generate_reply(context)
+
+ logger.debug("[WX] ready to decorate reply: {}".format(reply))
+ # reply的包装步骤
+ reply = self._decorate_reply(context, reply)
+
+ # reply的发送步骤
+ self._send_reply(context, reply)
+
+ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
+ e_context = PluginManager().emit_event(
+ EventContext(
+ Event.ON_HANDLE_CONTEXT,
+ {"channel": self, "context": context, "reply": reply},
+ )
+ )
+ reply = e_context["reply"]
+ if not e_context.is_pass():
+ logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
+ if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
+ context["channel"] = e_context["channel"]
+ reply = super().build_reply_content(context.content, context)
+ elif context.type == ContextType.VOICE: # 语音消息
+ cmsg = context["msg"]
+ cmsg.prepare()
+ file_path = context.content
+ wav_path = os.path.splitext(file_path)[0] + ".wav"
+ try:
+ any_to_wav(file_path, wav_path)
+ except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
+ logger.warning("[WX]any to wav error, use raw path. " + str(e))
+ wav_path = file_path
+ # 语音识别
+ reply = super().build_voice_to_text(wav_path)
+ # 删除临时文件
+ try:
+ os.remove(file_path)
+ if wav_path != file_path:
+ os.remove(wav_path)
+ except Exception as e:
+ pass
+ # logger.warning("[WX]delete temp file error: " + str(e))
+
+ if reply.type == ReplyType.TEXT:
+ new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
+ if new_context:
+ reply = self._generate_reply(new_context)
+ else:
+ return
+ elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
+ memory.USER_IMAGE_CACHE[context["session_id"]] = {
+ "path": context.content,
+ "msg": context.get("msg")
+ }
+ elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
+ pass
+ elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
+ pass
+ else:
+ logger.warning("[WX] unknown context type: {}".format(context.type))
+ return
+ return reply
+
+ def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
+ if reply and reply.type:
+ e_context = PluginManager().emit_event(
+ EventContext(
+ Event.ON_DECORATE_REPLY,
+ {"channel": self, "context": context, "reply": reply},
+ )
+ )
+ reply = e_context["reply"]
+ desire_rtype = context.get("desire_rtype")
+ if not e_context.is_pass() and reply and reply.type:
+ if reply.type in self.NOT_SUPPORT_REPLYTYPE:
+ logger.error("[WX]reply type not support: " + str(reply.type))
+ reply.type = ReplyType.ERROR
+ reply.content = "不支持发送的消息类型: " + str(reply.type)
+
+ if reply.type == ReplyType.TEXT:
+ reply_text = reply.content
+ if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
+ reply = super().build_text_to_voice(reply.content)
+ return self._decorate_reply(context, reply)
+ if context.get("isgroup", False):
+ if not context.get("no_need_at", False):
+ reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
+ reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
+ else:
+ reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
+ reply.content = reply_text
+ elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
+ reply.content = "[" + str(reply.type) + "]\n" + reply.content
+ elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
+ pass
+ else:
+ logger.error("[WX] unknown reply type: {}".format(reply.type))
+ return
+ if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
+ logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
+ return reply
+
+ def _send_reply(self, context: Context, reply: Reply):
+ if reply and reply.type:
+ e_context = PluginManager().emit_event(
+ EventContext(
+ Event.ON_SEND_REPLY,
+ {"channel": self, "context": context, "reply": reply},
+ )
+ )
+ reply = e_context["reply"]
+ if not e_context.is_pass() and reply and reply.type:
+ logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
+ self._send(reply, context)
+
+ def _send(self, reply: Reply, context: Context, retry_cnt=0):
+ try:
+ self.send(reply, context)
+ except Exception as e:
+ logger.error("[WX] sendMsg error: {}".format(str(e)))
+ if isinstance(e, NotImplementedError):
+ return
+ logger.exception(e)
+ if retry_cnt < 2:
+ time.sleep(3 + 3 * retry_cnt)
+ self._send(reply, context, retry_cnt + 1)
+
+ def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
+ logger.debug("Worker return success, session_id = {}".format(session_id))
+
+ def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
+ logger.exception("Worker return exception: {}".format(exception))
+
+ def _thread_pool_callback(self, session_id, **kwargs):
+ def func(worker: Future):
+ try:
+ worker_exception = worker.exception()
+ if worker_exception:
+ self._fail_callback(session_id, exception=worker_exception, **kwargs)
+ else:
+ self._success_callback(session_id, **kwargs)
+ except CancelledError as e:
+ logger.info("Worker cancelled, session_id = {}".format(session_id))
+ except Exception as e:
+ logger.exception("Worker raise exception: {}".format(e))
+ with self.lock:
+ self.sessions[session_id][1].release()
+
+ return func
+
+ def produce(self, context: Context):
+ session_id = context["session_id"]
+ with self.lock:
+ if session_id not in self.sessions:
+ self.sessions[session_id] = [
+ Dequeue(),
+ threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
+ ]
+ if context.type == ContextType.TEXT and context.content.startswith("#"):
+ self.sessions[session_id][0].putleft(context) # 优先处理管理命令
+ else:
+ self.sessions[session_id][0].put(context)
+
+ # 消费者函数,单独线程,用于从消息队列中取出消息并处理
+ def consume(self):
+ while True:
+ with self.lock:
+ session_ids = list(self.sessions.keys())
+ for session_id in session_ids:
+ context_queue, semaphore = self.sessions[session_id]
+ if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
+ if not context_queue.empty():
+ context = context_queue.get()
+ logger.debug("[WX] consume context: {}".format(context))
+ future: Future = handler_pool.submit(self._handle, context)
+ future.add_done_callback(self._thread_pool_callback(session_id, context=context))
+ if session_id not in self.futures:
+ self.futures[session_id] = []
+ self.futures[session_id].append(future)
+ elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
+ self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
+ assert len(self.futures[session_id]) == 0, "thread pool error"
+ del self.sessions[session_id]
+ else:
+ semaphore.release()
+ time.sleep(0.1)
+
+ # 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
+ def cancel_session(self, session_id):
+ with self.lock:
+ if session_id in self.sessions:
+ for future in self.futures[session_id]:
+ future.cancel()
+ cnt = self.sessions[session_id][0].qsize()
+ if cnt > 0:
+ logger.info("Cancel {} messages in session {}".format(cnt, session_id))
+ self.sessions[session_id][0] = Dequeue()
+
+ def cancel_all_session(self):
+ with self.lock:
+ for session_id in self.sessions:
+ for future in self.futures[session_id]:
+ future.cancel()
+ cnt = self.sessions[session_id][0].qsize()
+ if cnt > 0:
+ logger.info("Cancel {} messages in session {}".format(cnt, session_id))
+ self.sessions[session_id][0] = Dequeue()
+
+
+def check_prefix(content, prefix_list):
+ if not prefix_list:
+ return None
+ for prefix in prefix_list:
+ if content.startswith(prefix):
+ return prefix
+ return None
+
+
+def check_contain(content, keyword_list):
+ if not keyword_list:
+ return None
+ for ky in keyword_list:
+ if content.find(ky) != -1:
+ return True
+ return None
diff --git a/channel/chat_message.py b/channel/chat_message.py
new file mode 100644
index 0000000..ac0e5c2
--- /dev/null
+++ b/channel/chat_message.py
@@ -0,0 +1,87 @@
+"""
+本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
+
+填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
+
+ChatMessage
+msg_id: 消息id (必填)
+create_time: 消息创建时间
+
+ctype: 消息类型 : ContextType (必填)
+content: 消息内容, 如果是声音/图片,这里是文件路径 (必填)
+
+from_user_id: 发送者id (必填)
+from_user_nickname: 发送者昵称
+to_user_id: 接收者id (必填)
+to_user_nickname: 接收者昵称
+
+other_user_id: 对方的id,如果你是发送者,那这个就是接收者id,如果你是接收者,那这个就是发送者id,如果是群消息,那这一直是群id (必填)
+other_user_nickname: 同上
+
+is_group: 是否是群消息 (群聊必填)
+is_at: 是否被at
+
+- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
+actual_user_id: 实际发送者id (群聊必填)
+actual_user_nickname:实际发送者昵称
+self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
+
+_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
+_prepared: 是否已经调用过准备函数
+_rawmsg: 原始消息对象
+
+"""
+
+
+class ChatMessage(object):
+ msg_id = None
+ create_time = None
+
+ ctype = None
+ content = None
+
+ from_user_id = None
+ from_user_nickname = None
+ to_user_id = None
+ to_user_nickname = None
+ other_user_id = None
+ other_user_nickname = None
+ my_msg = False
+ self_display_name = None
+
+ is_group = False
+ is_at = False
+ actual_user_id = None
+ actual_user_nickname = None
+ at_list = None
+
+ _prepare_fn = None
+ _prepared = False
+ _rawmsg = None
+
+ def __init__(self, _rawmsg):
+ self._rawmsg = _rawmsg
+
+ def prepare(self):
+ if self._prepare_fn and not self._prepared:
+ self._prepared = True
+ self._prepare_fn()
+
+ def __str__(self):
+ return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}, at_list={}".format(
+ self.msg_id,
+ self.create_time,
+ self.ctype,
+ self.content,
+ self.from_user_id,
+ self.from_user_nickname,
+ self.to_user_id,
+ self.to_user_nickname,
+ self.other_user_id,
+ self.other_user_nickname,
+ self.is_group,
+ self.is_at,
+ self.actual_user_id,
+ self.actual_user_nickname,
+ self.at_list
+ )
diff --git a/channel/dingtalk/dingtalk_channel.py b/channel/dingtalk/dingtalk_channel.py
new file mode 100644
index 0000000..22ef889
--- /dev/null
+++ b/channel/dingtalk/dingtalk_channel.py
@@ -0,0 +1,100 @@
+"""
+钉钉通道接入
+
+@author huiwen
+@Date 2023/11/28
+"""
+
+# -*- coding=utf-8 -*-
+from channel.dingtalk.dingtalk_message import DingTalkMessage
+from bridge.context import Context
+from bridge.reply import Reply
+from common.log import logger
+from common.singleton import singleton
+from config import conf
+from common.expired_dict import ExpiredDict
+from bridge.context import ContextType
+from channel.chat_channel import ChatChannel
+import logging
+from dingtalk_stream import AckMessage
+import dingtalk_stream
+
+
+@singleton
+class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
+ dingtalk_client_id = conf().get('dingtalk_client_id')
+ dingtalk_client_secret = conf().get('dingtalk_client_secret')
+
+ def setup_logger(self):
+ logger = logging.getLogger()
+ handler = logging.StreamHandler()
+ handler.setFormatter(
+ logging.Formatter('%(asctime)s %(name)-8s %(levelname)-8s %(message)s [%(filename)s:%(lineno)d]'))
+ logger.addHandler(handler)
+ logger.setLevel(logging.INFO)
+ return logger
+
+ def __init__(self):
+ super().__init__()
+ super(dingtalk_stream.ChatbotHandler, self).__init__()
+ self.logger = self.setup_logger()
+ # 历史消息id暂存,用于幂等控制
+ self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
+ logger.info("[dingtalk] client_id={}, client_secret={} ".format(
+ self.dingtalk_client_id, self.dingtalk_client_secret))
+ # 无需群校验和前缀
+ conf()["group_name_white_list"] = ["ALL_GROUP"]
+
+ def startup(self):
+ credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
+ client = dingtalk_stream.DingTalkStreamClient(credential)
+ client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
+ client.start_forever()
+
+ def handle_single(self, cmsg: DingTalkMessage):
+ # 处理单聊消息
+ if cmsg.ctype == ContextType.VOICE:
+ logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.IMAGE:
+ logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.PATPAT:
+ logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.TEXT:
+ expression = cmsg.my_msg
+ cmsg.content = conf()["single_chat_prefix"][0] + cmsg.content
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
+ if context:
+ self.produce(context)
+
+ def handle_group(self, cmsg: DingTalkMessage):
+ # 处理群聊消息
+ if cmsg.ctype == ContextType.VOICE:
+ logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.IMAGE:
+ logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.PATPAT:
+ logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.TEXT:
+ expression = cmsg.my_msg
+ cmsg.content = conf()["group_chat_prefix"][0] + cmsg.content
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
+ context['no_need_at'] = True
+ if context:
+ self.produce(context)
+
+ async def process(self, callback: dingtalk_stream.CallbackMessage):
+ try:
+ incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
+ dingtalk_msg = DingTalkMessage(incoming_message)
+ if incoming_message.conversation_type == '1':
+ self.handle_single(dingtalk_msg)
+ else:
+ self.handle_group(dingtalk_msg)
+ return AckMessage.STATUS_OK, 'OK'
+ except Exception as e:
+ logger.error(e)
+ return self.FAILED_MSG
+
+ def send(self, reply: Reply, context: Context):
+ incoming_message = context.kwargs['msg'].incoming_message
+ self.reply_text(reply.content, incoming_message)
diff --git a/channel/dingtalk/dingtalk_message.py b/channel/dingtalk/dingtalk_message.py
new file mode 100644
index 0000000..ef9dc96
--- /dev/null
+++ b/channel/dingtalk/dingtalk_message.py
@@ -0,0 +1,44 @@
+from bridge.context import ContextType
+from channel.chat_message import ChatMessage
+import json
+import requests
+from common.log import logger
+from common.tmp_dir import TmpDir
+from common import utils
+from dingtalk_stream import ChatbotMessage
+
+class DingTalkMessage(ChatMessage):
+ def __init__(self, event: ChatbotMessage):
+ super().__init__(event)
+
+ self.msg_id = event.message_id
+ msg_type = event.message_type
+ self.incoming_message =event
+ self.sender_staff_id = event.sender_staff_id
+ self.other_user_id = event.conversation_id
+ self.create_time = event.create_at
+ if event.conversation_type=="1":
+ self.is_group = False
+ else:
+ self.is_group = True
+
+
+ if msg_type == "text":
+ self.ctype = ContextType.TEXT
+
+ self.content = event.text.content.strip()
+ elif msg_type == "audio":
+
+ # 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理
+ self.content = event.extensions['content']['recognition'].strip()
+ self.ctype = ContextType.TEXT
+ self.from_user_id = event.sender_id
+ self.to_user_id = event.chatbot_user_id
+ self.other_user_nickname = event.conversation_title
+
+ user_id = event.sender_id
+ nickname =event.sender_nick
+
+
+
+
diff --git a/channel/feishu/feishu_channel.py b/channel/feishu/feishu_channel.py
new file mode 100644
index 0000000..76fbbf1
--- /dev/null
+++ b/channel/feishu/feishu_channel.py
@@ -0,0 +1,254 @@
+"""
+飞书通道接入
+
+@author Saboteur7
+@Date 2023/11/19
+"""
+
+# -*- coding=utf-8 -*-
+import uuid
+
+import requests
+import web
+from channel.feishu.feishu_message import FeishuMessage
+from bridge.context import Context
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common.singleton import singleton
+from config import conf
+from common.expired_dict import ExpiredDict
+from bridge.context import ContextType
+from channel.chat_channel import ChatChannel, check_prefix
+from common import utils
+import json
+import os
+
+URL_VERIFICATION = "url_verification"
+
+
+@singleton
+class FeiShuChanel(ChatChannel):
+ feishu_app_id = conf().get('feishu_app_id')
+ feishu_app_secret = conf().get('feishu_app_secret')
+ feishu_token = conf().get('feishu_token')
+
+ def __init__(self):
+ super().__init__()
+ # 历史消息id暂存,用于幂等控制
+ self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
+ logger.info("[FeiShu] app_id={}, app_secret={} verification_token={}".format(
+ self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
+ # 无需群校验和前缀
+ conf()["group_name_white_list"] = ["ALL_GROUP"]
+ conf()["single_chat_prefix"] = []
+
+ def startup(self):
+ urls = (
+ '/', 'channel.feishu.feishu_channel.FeishuController'
+ )
+ app = web.application(urls, globals(), autoreload=False)
+ port = conf().get("feishu_port", 9891)
+ web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
+
+ def send(self, reply: Reply, context: Context):
+ msg = context.get("msg")
+ is_group = context["isgroup"]
+ if msg:
+ access_token = msg.access_token
+ else:
+ access_token = self.fetch_access_token()
+ headers = {
+ "Authorization": "Bearer " + access_token,
+ "Content-Type": "application/json",
+ }
+ msg_type = "text"
+ logger.info(f"[FeiShu] start send reply message, type={context.type}, content={reply.content}")
+ reply_content = reply.content
+ content_key = "text"
+ if reply.type == ReplyType.IMAGE_URL:
+ # 图片上传
+ reply_content = self._upload_image_url(reply.content, access_token)
+ if not reply_content:
+ logger.warning("[FeiShu] upload file failed")
+ return
+ msg_type = "image"
+ content_key = "image_key"
+ if is_group:
+ # 群聊中直接回复
+ url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
+ data = {
+ "msg_type": msg_type,
+ "content": json.dumps({content_key: reply_content})
+ }
+ res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
+ else:
+ url = "https://open.feishu.cn/open-apis/im/v1/messages"
+ params = {"receive_id_type": context.get("receive_id_type") or "open_id"}
+ data = {
+ "receive_id": context.get("receiver"),
+ "msg_type": msg_type,
+ "content": json.dumps({content_key: reply_content})
+ }
+ res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
+ res = res.json()
+ if res.get("code") == 0:
+ logger.info(f"[FeiShu] send message success")
+ else:
+ logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}")
+
+
+ def fetch_access_token(self) -> str:
+ url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/"
+ headers = {
+ "Content-Type": "application/json"
+ }
+ req_body = {
+ "app_id": self.feishu_app_id,
+ "app_secret": self.feishu_app_secret
+ }
+ data = bytes(json.dumps(req_body), encoding='utf8')
+ response = requests.post(url=url, data=data, headers=headers)
+ if response.status_code == 200:
+ res = response.json()
+ if res.get("code") != 0:
+ logger.error(f"[FeiShu] get tenant_access_token error, code={res.get('code')}, msg={res.get('msg')}")
+ return ""
+ else:
+ return res.get("tenant_access_token")
+ else:
+ logger.error(f"[FeiShu] fetch token error, res={response}")
+
+
+ def _upload_image_url(self, img_url, access_token):
+ logger.debug(f"[WX] start download image, img_url={img_url}")
+ response = requests.get(img_url)
+ suffix = utils.get_path_suffix(img_url)
+ temp_name = str(uuid.uuid4()) + "." + suffix
+ if response.status_code == 200:
+ # 将图片内容保存为临时文件
+ with open(temp_name, "wb") as file:
+ file.write(response.content)
+
+ # upload
+ upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
+ data = {
+ 'image_type': 'message'
+ }
+ headers = {
+ 'Authorization': f'Bearer {access_token}',
+ }
+ with open(temp_name, "rb") as file:
+ upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
+ logger.info(f"[FeiShu] upload file, res={upload_response.content}")
+ os.remove(temp_name)
+ return upload_response.json().get("data").get("image_key")
+
+
+
+class FeishuController:
+ # 类常量
+ FAILED_MSG = '{"success": false}'
+ SUCCESS_MSG = '{"success": true}'
+ MESSAGE_RECEIVE_TYPE = "im.message.receive_v1"
+
+ def GET(self):
+ return "Feishu service start success!"
+
+ def POST(self):
+ try:
+ channel = FeiShuChanel()
+
+ request = json.loads(web.data().decode("utf-8"))
+ logger.debug(f"[FeiShu] receive request: {request}")
+
+ # 1.事件订阅回调验证
+ if request.get("type") == URL_VERIFICATION:
+ varify_res = {"challenge": request.get("challenge")}
+ return json.dumps(varify_res)
+
+ # 2.消息接收处理
+ # token 校验
+ header = request.get("header")
+ if not header or header.get("token") != channel.feishu_token:
+ return self.FAILED_MSG
+
+ # 处理消息事件
+ event = request.get("event")
+ if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
+ if not event.get("message") or not event.get("sender"):
+ logger.warning(f"[FeiShu] invalid message, msg={request}")
+ return self.FAILED_MSG
+ msg = event.get("message")
+
+ # 幂等判断
+ if channel.receivedMsgs.get(msg.get("message_id")):
+ logger.warning(f"[FeiShu] repeat msg filtered, event_id={header.get('event_id')}")
+ return self.SUCCESS_MSG
+ channel.receivedMsgs[msg.get("message_id")] = True
+
+ is_group = False
+ chat_type = msg.get("chat_type")
+ if chat_type == "group":
+ if not msg.get("mentions") and msg.get("message_type") == "text":
+ # 群聊中未@不响应
+ return self.SUCCESS_MSG
+ if msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get("message_type") == "text":
+ # 不是@机器人,不响应
+ return self.SUCCESS_MSG
+ # 群聊
+ is_group = True
+ receive_id_type = "chat_id"
+ elif chat_type == "p2p":
+ receive_id_type = "open_id"
+ else:
+ logger.warning("[FeiShu] message ignore")
+ return self.SUCCESS_MSG
+ # 构造飞书消息对象
+ feishu_msg = FeishuMessage(event, is_group=is_group, access_token=channel.fetch_access_token())
+ if not feishu_msg:
+ return self.SUCCESS_MSG
+
+ context = self._compose_context(
+ feishu_msg.ctype,
+ feishu_msg.content,
+ isgroup=is_group,
+ msg=feishu_msg,
+ receive_id_type=receive_id_type,
+ no_need_at=True
+ )
+ if context:
+ channel.produce(context)
+ logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
+ return self.SUCCESS_MSG
+
+ except Exception as e:
+ logger.error(e)
+ return self.FAILED_MSG
+
+ def _compose_context(self, ctype: ContextType, content, **kwargs):
+ context = Context(ctype, content)
+ context.kwargs = kwargs
+ if "origin_ctype" not in context:
+ context["origin_ctype"] = ctype
+
+ cmsg = context["msg"]
+ context["session_id"] = cmsg.from_user_id
+ context["receiver"] = cmsg.other_user_id
+
+ if ctype == ContextType.TEXT:
+ # 1.文本请求
+ # 图片生成处理
+ img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
+ if img_match_prefix:
+ content = content.replace(img_match_prefix, "", 1)
+ context.type = ContextType.IMAGE_CREATE
+ else:
+ context.type = ContextType.TEXT
+ context.content = content.strip()
+
+ elif context.type == ContextType.VOICE:
+ # 2.语音请求
+ if "desire_rtype" not in context and conf().get("voice_reply_voice"):
+ context["desire_rtype"] = ReplyType.VOICE
+
+ return context
diff --git a/channel/feishu/feishu_message.py b/channel/feishu/feishu_message.py
new file mode 100644
index 0000000..e2054c1
--- /dev/null
+++ b/channel/feishu/feishu_message.py
@@ -0,0 +1,63 @@
+from bridge.context import ContextType
+from channel.chat_message import ChatMessage
+import json
+import requests
+from common.log import logger
+from common.tmp_dir import TmpDir
+from common import utils
+
+
+class FeishuMessage(ChatMessage):
+ def __init__(self, event: dict, is_group=False, access_token=None):
+ super().__init__(event)
+ msg = event.get("message")
+ sender = event.get("sender")
+ self.access_token = access_token
+ self.msg_id = msg.get("message_id")
+ self.create_time = msg.get("create_time")
+ self.is_group = is_group
+ msg_type = msg.get("message_type")
+
+ if msg_type == "text":
+ self.ctype = ContextType.TEXT
+ content = json.loads(msg.get('content'))
+ self.content = content.get("text").strip()
+ elif msg_type == "file":
+ self.ctype = ContextType.FILE
+ content = json.loads(msg.get("content"))
+ file_key = content.get("file_key")
+ file_name = content.get("file_name")
+
+ self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
+
+ def _download_file():
+ # 如果响应状态码是200,则将响应内容写入本地文件
+ url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
+ headers = {
+ "Authorization": "Bearer " + access_token,
+ }
+ params = {
+ "type": "file"
+ }
+ response = requests.get(url=url, headers=headers, params=params)
+ if response.status_code == 200:
+ with open(self.content, "wb") as f:
+ f.write(response.content)
+ else:
+ logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
+ self._prepare_fn = _download_file
+ else:
+ raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
+
+ self.from_user_id = sender.get("sender_id").get("open_id")
+ self.to_user_id = event.get("app_id")
+ if is_group:
+ # 群聊
+ self.other_user_id = msg.get("chat_id")
+ self.actual_user_id = self.from_user_id
+ self.content = self.content.replace("@_user_1", "").strip()
+ self.actual_user_nickname = ""
+ else:
+ # 私聊
+ self.other_user_id = self.from_user_id
+ self.actual_user_id = self.from_user_id
diff --git a/channel/terminal/terminal_channel.py b/channel/terminal/terminal_channel.py
new file mode 100644
index 0000000..9a413dc
--- /dev/null
+++ b/channel/terminal/terminal_channel.py
@@ -0,0 +1,92 @@
+import sys
+
+from bridge.context import *
+from bridge.reply import Reply, ReplyType
+from channel.chat_channel import ChatChannel, check_prefix
+from channel.chat_message import ChatMessage
+from common.log import logger
+from config import conf
+
+
+class TerminalMessage(ChatMessage):
+ def __init__(
+ self,
+ msg_id,
+ content,
+ ctype=ContextType.TEXT,
+ from_user_id="User",
+ to_user_id="Chatgpt",
+ other_user_id="Chatgpt",
+ ):
+ self.msg_id = msg_id
+ self.ctype = ctype
+ self.content = content
+ self.from_user_id = from_user_id
+ self.to_user_id = to_user_id
+ self.other_user_id = other_user_id
+
+
+class TerminalChannel(ChatChannel):
+ NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
+
+ def send(self, reply: Reply, context: Context):
+ print("\nBot:")
+ if reply.type == ReplyType.IMAGE:
+ from PIL import Image
+
+ image_storage = reply.content
+ image_storage.seek(0)
+ img = Image.open(image_storage)
+ print("")
+ img.show()
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+ import io
+
+ import requests
+ from PIL import Image
+
+ img_url = reply.content
+ pic_res = requests.get(img_url, stream=True)
+ image_storage = io.BytesIO()
+ for block in pic_res.iter_content(1024):
+ image_storage.write(block)
+ image_storage.seek(0)
+ img = Image.open(image_storage)
+ print(img_url)
+ img.show()
+ else:
+ print(reply.content)
+ print("\nUser:", end="")
+ sys.stdout.flush()
+ return
+
+ def startup(self):
+ context = Context()
+ logger.setLevel("WARN")
+ print("\nPlease input your question:\nUser:", end="")
+ sys.stdout.flush()
+ msg_id = 0
+ while True:
+ try:
+ prompt = self.get_input()
+ except KeyboardInterrupt:
+ print("\nExiting...")
+ sys.exit()
+ msg_id += 1
+ trigger_prefixs = conf().get("single_chat_prefix", [""])
+ if check_prefix(prompt, trigger_prefixs) is None:
+ prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
+
+ context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
+ if context:
+ self.produce(context)
+ else:
+ raise Exception("context is None")
+
+ def get_input(self):
+ """
+ Multi-line input function
+ """
+ sys.stdout.flush()
+ line = input()
+ return line
diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py
new file mode 100644
index 0000000..717b068
--- /dev/null
+++ b/channel/wechat/wechat_channel.py
@@ -0,0 +1,283 @@
+# encoding:utf-8
+
+"""
+wechat channel
+"""
+
+import io
+import json
+import os
+import threading
+import time
+
+import requests
+
+from bridge.context import *
+from bridge.reply import *
+from channel.chat_channel import ChatChannel
+from channel import chat_channel
+from channel.wechat.wechat_message import *
+from common.expired_dict import ExpiredDict
+from common.log import logger
+from common.singleton import singleton
+from common.time_check import time_checker
+from config import conf, get_appdata_dir
+from lib import itchat
+from lib.itchat.content import *
+
+
+@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING])
+def handler_single_msg(msg):
+ try:
+ cmsg = WechatMessage(msg, False)
+ except NotImplementedError as e:
+ logger.debug("[WX]single message {} skipped: {}".format(msg["MsgId"], e))
+ return None
+ WechatChannel().handle_single(cmsg)
+ return None
+
+
+@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING], isGroupChat=True)
+def handler_group_msg(msg):
+ try:
+ cmsg = WechatMessage(msg, True)
+ except NotImplementedError as e:
+ logger.debug("[WX]group message {} skipped: {}".format(msg["MsgId"], e))
+ return None
+ WechatChannel().handle_group(cmsg)
+ return None
+
+
+def _check(func):
+ def wrapper(self, cmsg: ChatMessage):
+ msgId = cmsg.msg_id
+ if msgId in self.receivedMsgs:
+ logger.info("Wechat message {} already received, ignore".format(msgId))
+ return
+ self.receivedMsgs[msgId] = True
+ create_time = cmsg.create_time # 消息时间戳
+ if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
+ logger.debug("[WX]history message {} skipped".format(msgId))
+ return
+ if cmsg.my_msg and not cmsg.is_group:
+ logger.debug("[WX]my message {} skipped".format(msgId))
+ return
+ return func(self, cmsg)
+
+ return wrapper
+
+
+# 可用的二维码生成接口
+# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
+# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
+def qrCallback(uuid, status, qrcode):
+ # logger.debug("qrCallback: {} {}".format(uuid,status))
+ if status == "0":
+ try:
+ from PIL import Image
+
+ img = Image.open(io.BytesIO(qrcode))
+ _thread = threading.Thread(target=img.show, args=("QRCode",))
+ _thread.setDaemon(True)
+ _thread.start()
+ except Exception as e:
+ pass
+
+ import qrcode
+
+ url = f"https://login.weixin.qq.com/l/{uuid}"
+
+ qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
+ qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
+ qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
+ qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
+ print("You can also scan QRCode in any website below:")
+ print(qr_api3)
+ print(qr_api4)
+ print(qr_api2)
+ print(qr_api1)
+ _send_qr_code([qr_api1, qr_api2, qr_api3, qr_api4])
+ qr = qrcode.QRCode(border=1)
+ qr.add_data(url)
+ qr.make(fit=True)
+ qr.print_ascii(invert=True)
+
+
+@singleton
+class WechatChannel(ChatChannel):
+ NOT_SUPPORT_REPLYTYPE = []
+
+ def __init__(self):
+ super().__init__()
+ self.receivedMsgs = ExpiredDict(60 * 60)
+ self.auto_login_times = 0
+
+ def startup(self):
+ try:
+ itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
+ # login by scan QRCode
+ hotReload = conf().get("hot_reload", False)
+ status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
+ itchat.auto_login(
+ enableCmdQR=2,
+ hotReload=hotReload,
+ statusStorageDir=status_path,
+ qrCallback=qrCallback,
+ exitCallback=self.exitCallback,
+ loginCallback=self.loginCallback
+ )
+ self.user_id = itchat.instance.storageClass.userName
+ self.name = itchat.instance.storageClass.nickName
+ logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
+ # start message listener
+ itchat.run()
+ except Exception as e:
+ logger.error(e)
+
+ def exitCallback(self):
+ try:
+ from common.linkai_client import chat_client
+ if chat_client.client_id and conf().get("use_linkai"):
+ _send_logout()
+ time.sleep(2)
+ self.auto_login_times += 1
+ if self.auto_login_times < 100:
+ chat_channel.handler_pool._shutdown = False
+ self.startup()
+ except Exception as e:
+ pass
+
+ def loginCallback(self):
+ logger.debug("Login success")
+ _send_login_success()
+
+ # handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复
+ # Context包含了消息的所有信息,包括以下属性
+ # type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
+ # content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
+ # kwargs 附加参数字典,包含以下的key:
+ # session_id: 会话id
+ # isgroup: 是否是群聊
+ # receiver: 需要回复的对象
+ # msg: ChatMessage消息对象
+ # origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
+ # desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
+ @time_checker
+ @_check
+ def handle_single(self, cmsg: ChatMessage):
+ # filter system message
+ if cmsg.other_user_id in ["weixin"]:
+ return
+ if cmsg.ctype == ContextType.VOICE:
+ if conf().get("speech_recognition") != True:
+ return
+ logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.IMAGE:
+ logger.debug("[WX]receive image msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.PATPAT:
+ logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.TEXT:
+ logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
+ else:
+ logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
+ if context:
+ self.produce(context)
+
+ @time_checker
+ @_check
+ def handle_group(self, cmsg: ChatMessage):
+ if cmsg.ctype == ContextType.VOICE:
+ if conf().get("group_speech_recognition") != True:
+ return
+ logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.IMAGE:
+ logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
+ elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.ACCEPT_FRIEND, ContextType.EXIT_GROUP]:
+ logger.debug("[WX]receive note msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.TEXT:
+ # logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
+ pass
+ elif cmsg.ctype == ContextType.FILE:
+ logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}")
+ else:
+ logger.debug("[WX]receive group msg: {}".format(cmsg.content))
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
+ if context:
+ self.produce(context)
+
+ # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
+ def send(self, reply: Reply, context: Context):
+ receiver = context["receiver"]
+ if reply.type == ReplyType.TEXT:
+ itchat.send(reply.content, toUserName=receiver)
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
+ elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
+ itchat.send(reply.content, toUserName=receiver)
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
+ elif reply.type == ReplyType.VOICE:
+ itchat.send_file(reply.content, toUserName=receiver)
+ logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+ img_url = reply.content
+ logger.debug(f"[WX] start download image, img_url={img_url}")
+ pic_res = requests.get(img_url, stream=True)
+ image_storage = io.BytesIO()
+ size = 0
+ for block in pic_res.iter_content(1024):
+ size += len(block)
+ image_storage.write(block)
+ logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
+ image_storage.seek(0)
+ itchat.send_image(image_storage, toUserName=receiver)
+ logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+ image_storage = reply.content
+ image_storage.seek(0)
+ itchat.send_image(image_storage, toUserName=receiver)
+ logger.info("[WX] sendImage, receiver={}".format(receiver))
+ elif reply.type == ReplyType.FILE: # 新增文件回复类型
+ file_storage = reply.content
+ itchat.send_file(file_storage, toUserName=receiver)
+ logger.info("[WX] sendFile, receiver={}".format(receiver))
+ elif reply.type == ReplyType.VIDEO: # 新增视频回复类型
+ video_storage = reply.content
+ itchat.send_video(video_storage, toUserName=receiver)
+ logger.info("[WX] sendFile, receiver={}".format(receiver))
+ elif reply.type == ReplyType.VIDEO_URL: # 新增视频URL回复类型
+ video_url = reply.content
+ logger.debug(f"[WX] start download video, video_url={video_url}")
+ video_res = requests.get(video_url, stream=True)
+ video_storage = io.BytesIO()
+ size = 0
+ for block in video_res.iter_content(1024):
+ size += len(block)
+ video_storage.write(block)
+ logger.info(f"[WX] download video success, size={size}, video_url={video_url}")
+ video_storage.seek(0)
+ itchat.send_video(video_storage, toUserName=receiver)
+ logger.info("[WX] sendVideo url={}, receiver={}".format(video_url, receiver))
+
+def _send_login_success():
+ try:
+ from common.linkai_client import chat_client
+ if chat_client.client_id:
+ chat_client.send_login_success()
+ except Exception as e:
+ pass
+
+def _send_logout():
+ try:
+ from common.linkai_client import chat_client
+ if chat_client.client_id:
+ chat_client.send_logout()
+ except Exception as e:
+ pass
+
+def _send_qr_code(qrcode_list: list):
+ try:
+ from common.linkai_client import chat_client
+ if chat_client.client_id:
+ chat_client.send_qrcode(qrcode_list)
+ except Exception as e:
+ pass
diff --git a/channel/wechat/wechat_message.py b/channel/wechat/wechat_message.py
new file mode 100644
index 0000000..b8b1d91
--- /dev/null
+++ b/channel/wechat/wechat_message.py
@@ -0,0 +1,102 @@
+import re
+
+from bridge.context import ContextType
+from channel.chat_message import ChatMessage
+from common.log import logger
+from common.tmp_dir import TmpDir
+from lib import itchat
+from lib.itchat.content import *
+
+class WechatMessage(ChatMessage):
+ def __init__(self, itchat_msg, is_group=False):
+ super().__init__(itchat_msg)
+ self.msg_id = itchat_msg["MsgId"]
+ self.create_time = itchat_msg["CreateTime"]
+ self.is_group = is_group
+
+ if itchat_msg["Type"] == TEXT:
+ self.ctype = ContextType.TEXT
+ self.content = itchat_msg["Text"]
+ elif itchat_msg["Type"] == VOICE:
+ self.ctype = ContextType.VOICE
+ self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
+ self._prepare_fn = lambda: itchat_msg.download(self.content)
+ elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
+ self.ctype = ContextType.IMAGE
+ self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
+ self._prepare_fn = lambda: itchat_msg.download(self.content)
+ elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
+ if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
+ # 这里只能得到nickname, actual_user_id还是机器人的id
+ if "加入了群聊" in itchat_msg["Content"]:
+ self.ctype = ContextType.JOIN_GROUP
+ self.content = itchat_msg["Content"]
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
+ elif "加入群聊" in itchat_msg["Content"]:
+ self.ctype = ContextType.JOIN_GROUP
+ self.content = itchat_msg["Content"]
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
+
+ elif is_group and ("移出了群聊" in itchat_msg["Content"]):
+ self.ctype = ContextType.EXIT_GROUP
+ self.content = itchat_msg["Content"]
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
+
+ elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
+ self.ctype = ContextType.ACCEPT_FRIEND
+ self.content = itchat_msg["Content"]
+ elif "拍了拍我" in itchat_msg["Content"]:
+ self.ctype = ContextType.PATPAT
+ self.content = itchat_msg["Content"]
+ if is_group:
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
+ else:
+ raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
+ elif itchat_msg["Type"] == ATTACHMENT:
+ self.ctype = ContextType.FILE
+ self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
+ self._prepare_fn = lambda: itchat_msg.download(self.content)
+ elif itchat_msg["Type"] == SHARING:
+ self.ctype = ContextType.SHARING
+ self.content = itchat_msg.get("Url")
+
+ else:
+ raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
+
+ self.from_user_id = itchat_msg["FromUserName"]
+ self.to_user_id = itchat_msg["ToUserName"]
+
+ user_id = itchat.instance.storageClass.userName
+ nickname = itchat.instance.storageClass.nickName
+
+ # 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
+ # 以下很繁琐,一句话总结:能填的都填了。
+ if self.from_user_id == user_id:
+ self.from_user_nickname = nickname
+ if self.to_user_id == user_id:
+ self.to_user_nickname = nickname
+ try: # 陌生人时候, User字段可能不存在
+ # my_msg 为True是表示是自己发送的消息
+ self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
+ itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
+ self.other_user_id = itchat_msg["User"]["UserName"]
+ self.other_user_nickname = itchat_msg["User"]["NickName"]
+ if self.other_user_id == self.from_user_id:
+ self.from_user_nickname = self.other_user_nickname
+ if self.other_user_id == self.to_user_id:
+ self.to_user_nickname = self.other_user_nickname
+ if itchat_msg["User"].get("Self"):
+ # 自身的展示名,当设置了群昵称时,该字段表示群昵称
+ self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
+ except KeyError as e: # 处理偶尔没有对方信息的情况
+ logger.warn("[WX]get other_user_id failed: " + str(e))
+ if self.from_user_id == user_id:
+ self.other_user_id = self.to_user_id
+ else:
+ self.other_user_id = self.from_user_id
+
+ if self.is_group:
+ self.is_at = itchat_msg["IsAt"]
+ self.actual_user_id = itchat_msg["ActualUserName"]
+ if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.EXIT_GROUP]:
+ self.actual_user_nickname = itchat_msg["ActualNickName"]
diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py
new file mode 100644
index 0000000..051a9cf
--- /dev/null
+++ b/channel/wechat/wechaty_channel.py
@@ -0,0 +1,129 @@
+# encoding:utf-8
+
+"""
+wechaty channel
+Python Wechaty - https://github.com/wechaty/python-wechaty
+"""
+import asyncio
+import base64
+import os
+import time
+
+from wechaty import Contact, Wechaty
+from wechaty.user import Message
+from wechaty_puppet import FileBox
+
+from bridge.context import *
+from bridge.context import Context
+from bridge.reply import *
+from channel.chat_channel import ChatChannel
+from channel.wechat.wechaty_message import WechatyMessage
+from common.log import logger
+from common.singleton import singleton
+from config import conf
+
+try:
+ from voice.audio_convert import any_to_sil
+except Exception as e:
+ pass
+
+
+@singleton
+class WechatyChannel(ChatChannel):
+ NOT_SUPPORT_REPLYTYPE = []
+
+ def __init__(self):
+ super().__init__()
+
+ def startup(self):
+ config = conf()
+ token = config.get("wechaty_puppet_service_token")
+ os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
+ asyncio.run(self.main())
+
+ async def main(self):
+ loop = asyncio.get_event_loop()
+ # 将asyncio的loop传入处理线程
+ self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
+ self.bot = Wechaty()
+ self.bot.on("login", self.on_login)
+ self.bot.on("message", self.on_message)
+ await self.bot.start()
+
+ async def on_login(self, contact: Contact):
+ self.user_id = contact.contact_id
+ self.name = contact.name
+ logger.info("[WX] login user={}".format(contact))
+
+ # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
+ def send(self, reply: Reply, context: Context):
+ receiver_id = context["receiver"]
+ loop = asyncio.get_event_loop()
+ if context["isgroup"]:
+ receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
+ else:
+ receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
+ msg = None
+ if reply.type == ReplyType.TEXT:
+ msg = reply.content
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
+ elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
+ msg = reply.content
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
+ elif reply.type == ReplyType.VOICE:
+ voiceLength = None
+ file_path = reply.content
+ sil_file = os.path.splitext(file_path)[0] + ".sil"
+ voiceLength = int(any_to_sil(file_path, sil_file))
+ if voiceLength >= 60000:
+ voiceLength = 60000
+ logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
+ # 发送语音
+ t = int(time.time())
+ msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
+ if voiceLength is not None:
+ msg.metadata["voiceLength"] = voiceLength
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
+ try:
+ os.remove(file_path)
+ if sil_file != file_path:
+ os.remove(sil_file)
+ except Exception as e:
+ pass
+ logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+ img_url = reply.content
+ t = int(time.time())
+ msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
+ logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+ image_storage = reply.content
+ image_storage.seek(0)
+ t = int(time.time())
+ msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
+ logger.info("[WX] sendImage, receiver={}".format(receiver))
+
+ async def on_message(self, msg: Message):
+ """
+ listen for message event
+ """
+ try:
+ cmsg = await WechatyMessage(msg)
+ except NotImplementedError as e:
+ logger.debug("[WX] {}".format(e))
+ return
+ except Exception as e:
+ logger.exception("[WX] {}".format(e))
+ return
+ logger.debug("[WX] message:{}".format(cmsg))
+ room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
+ isgroup = room is not None
+ ctype = cmsg.ctype
+ context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
+ if context:
+ logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
+ self.produce(context)
diff --git a/channel/wechat/wechaty_message.py b/channel/wechat/wechaty_message.py
new file mode 100644
index 0000000..cdb41dd
--- /dev/null
+++ b/channel/wechat/wechaty_message.py
@@ -0,0 +1,89 @@
+import asyncio
+import re
+
+from wechaty import MessageType
+from wechaty.user import Message
+
+from bridge.context import ContextType
+from channel.chat_message import ChatMessage
+from common.log import logger
+from common.tmp_dir import TmpDir
+
+
+class aobject(object):
+ """Inheriting this class allows you to define an async __init__.
+
+ So you can create objects by doing something like `await MyClass(params)`
+ """
+
+ async def __new__(cls, *a, **kw):
+ instance = super().__new__(cls)
+ await instance.__init__(*a, **kw)
+ return instance
+
+ async def __init__(self):
+ pass
+
+
+class WechatyMessage(ChatMessage, aobject):
+ async def __init__(self, wechaty_msg: Message):
+ super().__init__(wechaty_msg)
+
+ room = wechaty_msg.room()
+
+ self.msg_id = wechaty_msg.message_id
+ self.create_time = wechaty_msg.payload.timestamp
+ self.is_group = room is not None
+
+ if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
+ self.ctype = ContextType.TEXT
+ self.content = wechaty_msg.text()
+ elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
+ self.ctype = ContextType.VOICE
+ voice_file = await wechaty_msg.to_file_box()
+ self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径
+
+ def func():
+ loop = asyncio.get_event_loop()
+ asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
+
+ self._prepare_fn = func
+
+ else:
+ raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
+
+ from_contact = wechaty_msg.talker() # 获取消息的发送者
+ self.from_user_id = from_contact.contact_id
+ self.from_user_nickname = from_contact.name
+
+ # group中的from和to,wechaty跟itchat含义不一样
+ # wecahty: from是消息实际发送者, to:所在群
+ # itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
+ # 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
+
+ if self.is_group:
+ self.to_user_id = room.room_id
+ self.to_user_nickname = await room.topic()
+ else:
+ to_contact = wechaty_msg.to()
+ self.to_user_id = to_contact.contact_id
+ self.to_user_nickname = to_contact.name
+
+ if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
+ self.other_user_id = self.to_user_id
+ self.other_user_nickname = self.to_user_nickname
+ else:
+ self.other_user_id = self.from_user_id
+ self.other_user_nickname = self.from_user_nickname
+
+ if self.is_group: # wechaty群聊中,实际发送用户就是from_user
+ self.is_at = await wechaty_msg.mention_self()
+ if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
+ name = wechaty_msg.wechaty.user_self().name
+ pattern = f"@{re.escape(name)}(\u2005|\u0020)"
+ if re.search(pattern, self.content):
+ logger.debug(f"wechaty message {self.msg_id} include at")
+ self.is_at = True
+
+ self.actual_user_id = self.from_user_id
+ self.actual_user_nickname = self.from_user_nickname
diff --git a/channel/wechatcom/README.md b/channel/wechatcom/README.md
new file mode 100644
index 0000000..2f54a79
--- /dev/null
+++ b/channel/wechatcom/README.md
@@ -0,0 +1,85 @@
+# 企业微信应用号channel
+
+企业微信官方提供了客服、应用等API,本channel使用的是企业微信的自建应用API的能力。
+
+因为未来可能还会开发客服能力,所以本channel的类型名叫作`wechatcom_app`。
+
+`wechatcom_app` channel支持插件系统和图片声音交互等能力,除了无法加入群聊,作为个人使用的私人助理已绰绰有余。
+
+## 开始之前
+
+- 在企业中确认自己拥有在企业内自建应用的权限。
+- 如果没有权限或者是个人用户,也可创建未认证的企业。操作方式:登录手机企业微信,选择`创建/加入企业`来创建企业,类型请选择企业,企业名称可随意填写。
+ 未认证的企业有100人的服务人数上限,其他功能与认证企业没有差异。
+
+本channel需安装的依赖与公众号一致,需要安装`wechatpy`和`web.py`,它们包含在`requirements-optional.txt`中。
+
+此外,如果你是`Linux`系统,除了`ffmpeg`还需要安装`amr`编码器,否则会出现找不到编码器的错误,无法正常使用语音功能。
+
+- Ubuntu/Debian
+
+```bash
+apt-get install libavcodec-extra
+```
+
+- Alpine
+
+需自行编译`ffmpeg`,在编译参数里加入`amr`编码器的支持
+
+## 使用方法
+
+1.查看企业ID
+
+- 扫码登陆[企业微信后台](https://work.weixin.qq.com)
+- 选择`我的企业`,点击`企业信息`,记住该`企业ID`
+
+2.创建自建应用
+
+- 选择应用管理, 在自建区选创建应用来创建企业自建应用
+- 上传应用logo,填写应用名称等项
+- 创建应用后进入应用详情页面,记住`AgentId`和`Secert`
+
+3.配置应用
+
+- 在详情页点击`企业可信IP`的配置(没看到可以不管),填入你服务器的公网IP,如果不知道可以先不填
+- 点击`接收消息`下的启用API接收消息
+- `URL`填写格式为`http://url:port/wxcomapp`,`port`是程序监听的端口,默认是9898
+ 如果是未认证的企业,url可直接使用服务器的IP。如果是认证企业,需要使用备案的域名,可使用二级域名。
+- `Token`可随意填写,停留在这个页面
+- 在程序根目录`config.json`中增加配置(**去掉注释**),`wechatcomapp_aes_key`是当前页面的`wechatcomapp_aes_key`
+
+```python
+ "channel_type": "wechatcom_app",
+ "wechatcom_corp_id": "", # 企业微信公司的corpID
+ "wechatcomapp_token": "", # 企业微信app的token
+ "wechatcomapp_port": 9898, # 企业微信app的服务端口, 不需要端口转发
+ "wechatcomapp_secret": "", # 企业微信app的secret
+ "wechatcomapp_agent_id": "", # 企业微信app的agent_id
+ "wechatcomapp_aes_key": "", # 企业微信app的aes_key
+```
+
+- 运行程序,在页面中点击保存,保存成功说明验证成功
+
+4.连接个人微信
+
+选择`我的企业`,点击`微信插件`,下面有个邀请关注的二维码。微信扫码后,即可在微信中看到对应企业,在这里你便可以和机器人沟通。
+
+向机器人发送消息,如果日志里出现报错:
+
+```bash
+Error code: 60020, message: "not allow to access from your ip, ...from ip: xx.xx.xx.xx"
+```
+
+意思是IP不可信,需要参考上一步的`企业可信IP`配置,把这里的IP加进去。
+
+~~### Railway部署方式~~(2023-06-08已失效)
+
+~~公众号不能在`Railway`上部署,但企业微信应用[可以](https://railway.app/template/-FHS--?referralCode=RC3znh)!~~
+
+~~填写配置后,将部署完成后的网址```**.railway.app/wxcomapp```,填写在上一步的URL中。发送信息后观察日志,把报错的IP加入到可信IP。(每次重启后都需要加入可信IP)~~
+
+## 测试体验
+
+AIGC开放社区中已经部署了多个可免费使用的Bot,扫描下方的二维码会自动邀请你来体验。
+
+
diff --git a/channel/wechatcom/wechatcomapp_channel.py b/channel/wechatcom/wechatcomapp_channel.py
new file mode 100644
index 0000000..1a08596
--- /dev/null
+++ b/channel/wechatcom/wechatcomapp_channel.py
@@ -0,0 +1,178 @@
+# -*- coding=utf-8 -*-
+import io
+import os
+import time
+
+import requests
+import web
+from wechatpy.enterprise import create_reply, parse_message
+from wechatpy.enterprise.crypto import WeChatCrypto
+from wechatpy.enterprise.exceptions import InvalidCorpIdException
+from wechatpy.exceptions import InvalidSignatureException, WeChatClientException
+
+from bridge.context import Context
+from bridge.reply import Reply, ReplyType
+from channel.chat_channel import ChatChannel
+from channel.wechatcom.wechatcomapp_client import WechatComAppClient
+from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
+from common.log import logger
+from common.singleton import singleton
+from common.utils import compress_imgfile, fsize, split_string_by_utf8_length
+from config import conf, subscribe_msg
+from voice.audio_convert import any_to_amr, split_audio
+
+MAX_UTF8_LEN = 2048
+
+
+@singleton
+class WechatComAppChannel(ChatChannel):
+ NOT_SUPPORT_REPLYTYPE = []
+
+ def __init__(self):
+ super().__init__()
+ self.corp_id = conf().get("wechatcom_corp_id")
+ self.secret = conf().get("wechatcomapp_secret")
+ self.agent_id = conf().get("wechatcomapp_agent_id")
+ self.token = conf().get("wechatcomapp_token")
+ self.aes_key = conf().get("wechatcomapp_aes_key")
+ print(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
+ logger.info(
+ "[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
+ )
+ self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id)
+ self.client = WechatComAppClient(self.corp_id, self.secret)
+
+ def startup(self):
+ # start message listener
+ urls = ("/wxcomapp", "channel.wechatcom.wechatcomapp_channel.Query")
+ app = web.application(urls, globals(), autoreload=False)
+ port = conf().get("wechatcomapp_port", 9898)
+ web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
+
+ def send(self, reply: Reply, context: Context):
+ receiver = context["receiver"]
+ if reply.type in [ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO]:
+ reply_text = reply.content
+ texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
+ if len(texts) > 1:
+ logger.info("[wechatcom] text too long, split into {} parts".format(len(texts)))
+ for i, text in enumerate(texts):
+ self.client.message.send_text(self.agent_id, receiver, text)
+ if i != len(texts) - 1:
+ time.sleep(0.5) # 休眠0.5秒,防止发送过快乱序
+ logger.info("[wechatcom] Do send text to {}: {}".format(receiver, reply_text))
+ elif reply.type == ReplyType.VOICE:
+ try:
+ media_ids = []
+ file_path = reply.content
+ amr_file = os.path.splitext(file_path)[0] + ".amr"
+ any_to_amr(file_path, amr_file)
+ duration, files = split_audio(amr_file, 60 * 1000)
+ if len(files) > 1:
+ logger.info("[wechatcom] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
+ for path in files:
+ response = self.client.media.upload("voice", open(path, "rb"))
+ logger.debug("[wechatcom] upload voice response: {}".format(response))
+ media_ids.append(response["media_id"])
+ except WeChatClientException as e:
+ logger.error("[wechatcom] upload voice failed: {}".format(e))
+ return
+ try:
+ os.remove(file_path)
+ if amr_file != file_path:
+ os.remove(amr_file)
+ except Exception:
+ pass
+ for media_id in media_ids:
+ self.client.message.send_voice(self.agent_id, receiver, media_id)
+ time.sleep(1)
+ logger.info("[wechatcom] sendVoice={}, receiver={}".format(reply.content, receiver))
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+ img_url = reply.content
+ pic_res = requests.get(img_url, stream=True)
+ image_storage = io.BytesIO()
+ for block in pic_res.iter_content(1024):
+ image_storage.write(block)
+ sz = fsize(image_storage)
+ if sz >= 10 * 1024 * 1024:
+ logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
+ image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
+ logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
+ image_storage.seek(0)
+ try:
+ response = self.client.media.upload("image", image_storage)
+ logger.debug("[wechatcom] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatcom] upload image failed: {}".format(e))
+ return
+
+ self.client.message.send_image(self.agent_id, receiver, response["media_id"])
+ logger.info("[wechatcom] sendImage url={}, receiver={}".format(img_url, receiver))
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+ image_storage = reply.content
+ sz = fsize(image_storage)
+ if sz >= 10 * 1024 * 1024:
+ logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
+ image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
+ logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
+ image_storage.seek(0)
+ try:
+ response = self.client.media.upload("image", image_storage)
+ logger.debug("[wechatcom] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatcom] upload image failed: {}".format(e))
+ return
+ self.client.message.send_image(self.agent_id, receiver, response["media_id"])
+ logger.info("[wechatcom] sendImage, receiver={}".format(receiver))
+
+
+class Query:
+ def GET(self):
+ channel = WechatComAppChannel()
+ params = web.input()
+ logger.info("[wechatcom] receive params: {}".format(params))
+ try:
+ signature = params.msg_signature
+ timestamp = params.timestamp
+ nonce = params.nonce
+ echostr = params.echostr
+ echostr = channel.crypto.check_signature(signature, timestamp, nonce, echostr)
+ except InvalidSignatureException:
+ raise web.Forbidden()
+ return echostr
+
+ def POST(self):
+ channel = WechatComAppChannel()
+ params = web.input()
+ logger.info("[wechatcom] receive params: {}".format(params))
+ try:
+ signature = params.msg_signature
+ timestamp = params.timestamp
+ nonce = params.nonce
+ message = channel.crypto.decrypt_message(web.data(), signature, timestamp, nonce)
+ except (InvalidSignatureException, InvalidCorpIdException):
+ raise web.Forbidden()
+ msg = parse_message(message)
+ logger.debug("[wechatcom] receive message: {}, msg= {}".format(message, msg))
+ if msg.type == "event":
+ if msg.event == "subscribe":
+ reply_content = subscribe_msg()
+ if reply_content:
+ reply = create_reply(reply_content, msg).render()
+ res = channel.crypto.encrypt_message(reply, nonce, timestamp)
+ return res
+ else:
+ try:
+ wechatcom_msg = WechatComAppMessage(msg, client=channel.client)
+ except NotImplementedError as e:
+ logger.debug("[wechatcom] " + str(e))
+ return "success"
+ context = channel._compose_context(
+ wechatcom_msg.ctype,
+ wechatcom_msg.content,
+ isgroup=False,
+ msg=wechatcom_msg,
+ )
+ if context:
+ channel.produce(context)
+ return "success"
diff --git a/channel/wechatcom/wechatcomapp_client.py b/channel/wechatcom/wechatcomapp_client.py
new file mode 100644
index 0000000..c0feb7a
--- /dev/null
+++ b/channel/wechatcom/wechatcomapp_client.py
@@ -0,0 +1,21 @@
+import threading
+import time
+
+from wechatpy.enterprise import WeChatClient
+
+
+class WechatComAppClient(WeChatClient):
+ def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True):
+ super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry)
+ self.fetch_access_token_lock = threading.Lock()
+
+ def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
+ with self.fetch_access_token_lock:
+ access_token = self.session.get(self.access_token_key)
+ if access_token:
+ if not self.expires_at:
+ return access_token
+ timestamp = time.time()
+ if self.expires_at - timestamp > 60:
+ return access_token
+ return super().fetch_access_token()
diff --git a/channel/wechatcom/wechatcomapp_message.py b/channel/wechatcom/wechatcomapp_message.py
new file mode 100644
index 0000000..a70f755
--- /dev/null
+++ b/channel/wechatcom/wechatcomapp_message.py
@@ -0,0 +1,52 @@
+from wechatpy.enterprise import WeChatClient
+
+from bridge.context import ContextType
+from channel.chat_message import ChatMessage
+from common.log import logger
+from common.tmp_dir import TmpDir
+
+
+class WechatComAppMessage(ChatMessage):
+ def __init__(self, msg, client: WeChatClient, is_group=False):
+ super().__init__(msg)
+ self.msg_id = msg.id
+ self.create_time = msg.time
+ self.is_group = is_group
+
+ if msg.type == "text":
+ self.ctype = ContextType.TEXT
+ self.content = msg.content
+ elif msg.type == "voice":
+ self.ctype = ContextType.VOICE
+ self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
+
+ def download_voice():
+ # 如果响应状态码是200,则将响应内容写入本地文件
+ response = client.media.download(msg.media_id)
+ if response.status_code == 200:
+ with open(self.content, "wb") as f:
+ f.write(response.content)
+ else:
+ logger.info(f"[wechatcom] Failed to download voice file, {response.content}")
+
+ self._prepare_fn = download_voice
+ elif msg.type == "image":
+ self.ctype = ContextType.IMAGE
+ self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
+
+ def download_image():
+ # 如果响应状态码是200,则将响应内容写入本地文件
+ response = client.media.download(msg.media_id)
+ if response.status_code == 200:
+ with open(self.content, "wb") as f:
+ f.write(response.content)
+ else:
+ logger.info(f"[wechatcom] Failed to download image file, {response.content}")
+
+ self._prepare_fn = download_image
+ else:
+ raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
+
+ self.from_user_id = msg.source
+ self.to_user_id = msg.target
+ self.other_user_id = msg.source
diff --git a/channel/wechatmp/README.md b/channel/wechatmp/README.md
new file mode 100644
index 0000000..8d753d8
--- /dev/null
+++ b/channel/wechatmp/README.md
@@ -0,0 +1,100 @@
+# 微信公众号channel
+
+鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。
+目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制。
+
+## 使用方法(订阅号,服务号类似)
+
+在开始部署前,你需要一个拥有公网IP的服务器,以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透,否则微信服务器无法将消息发送给我们的服务器。
+
+此外,需要在我们的服务器上安装python的web框架web.py和wechatpy。
+以ubuntu为例(在ubuntu 22.04上测试):
+```
+pip3 install web.py
+pip3 install wechatpy
+```
+
+然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。
+
+然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。`URL`填写格式为`http://url/wx`,可使用IP(成功几率看脸),`Token`是你自己编的一个特定的令牌。消息加解密方式如果选择了需要加密的模式,需要在配置中填写`wechatmp_aes_key`。
+
+相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
+```
+"channel_type": "wechatmp", # 如果通过了微信认证,将"wechatmp"替换为"wechatmp_service",可极大的优化使用体验
+"wechatmp_token": "xxxx", # 微信公众平台的Token
+"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
+"wechatmp_app_id": "xxxx", # 微信公众平台的appID
+"wechatmp_app_secret": "xxxx", # 微信公众平台的appsecret
+"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
+"single_chat_prefix": [""], # 推荐设置,任意对话都可以触发回复,不添加前缀
+"single_chat_reply_prefix": "", # 推荐设置,回复不设置前缀
+"plugin_trigger_prefix": "&", # 推荐设置,在手机微信客户端中,$%^等符号与中文连在一起时会自动显示一段较大的间隔,用户体验不好。请不要使用管理员指令前缀"#",这会造成未知问题。
+```
+然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口:
+```
+sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
+sudo iptables-save > /etc/iptables/rules.v4
+```
+第二个方法是让python程序直接监听80端口,在配置文件中设置`"wechatmp_port": 80` ,在linux上需要使用`sudo python3 app.py`启动程序。然而这会导致一系列环境和权限问题,因此不是推荐的方法。
+
+443端口同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`中需要修改相应的证书路径。
+
+程序启动并监听端口后,在刚才的“服务器配置”中点击`提交`即可验证你的服务器。
+随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器,关闭手动填写规则的自动回复,即可实现ChatGPT的自动回复。
+
+之后需要在公众号开发信息下将本机IP加入到IP白名单。
+
+不然在启用后,发送语音、图片等消息可能会遇到如下报错:
+```
+'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid
+```
+
+
+## 个人微信公众号的限制
+由于人微信公众号不能通过微信认证,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。
+
+另外,由于微信官方的限制,自动回复有长度限制。因此这里将ChatGPT的回答进行了拆分,以满足限制。
+
+## 私有api_key
+公共api有访问频率限制(免费账号每分钟最多3次ChatGPT的API调用),这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。
+
+## 语音输入
+利用微信自带的语音识别功能,提供语音输入能力。需要在公众号管理页面的“设置与开发”->“接口权限”页面开启“接收语音识别结果”。
+
+## 语音回复
+请在配置文件中添加以下词条:
+```
+ "voice_reply_voice": true,
+```
+这样公众号将会用语音回复语音消息,实现语音对话。
+
+默认的语音合成引擎是`google`,它是免费使用的。
+
+如果要选择其他的语音合成引擎,请添加以下配置项:
+```
+"text_to_voice": "pytts"
+```
+
+pytts是本地的语音合成引擎。还支持baidu,azure,这些你需要自行配置相关的依赖和key。
+
+如果使用pytts,在ubuntu上需要安装如下依赖:
+```
+sudo apt update
+sudo apt install espeak
+sudo apt install ffmpeg
+python3 -m pip install pyttsx3
+```
+不是很建议开启pytts语音回复,因为它是离线本地计算,算的慢会拖垮服务器,且声音不好听。
+
+## 图片回复
+现在认证公众号和非认证公众号都可以实现的图片和语音回复。但是非认证公众号使用了永久素材接口,每天有1000次的调用上限(每个月有10次重置机会,程序中已设定遇到上限会自动重置),且永久素材库存也有上限。因此对于非认证公众号,我们会在回复图片或者语音消息后的10秒内从永久素材库存内删除该素材。
+
+## 测试
+目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)),感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件,其他的插件还没有详尽测试。百度的接口暂未测试。[wechatmp-stable分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp-stable)是较稳定的上个版本,但也缺少最新的功能支持。
+
+## TODO
+ - [x] 语音输入
+ - [x] 图片输入
+ - [x] 使用临时素材接口提供认证公众号的图片和语音回复
+ - [x] 使用永久素材接口提供未认证公众号的图片和语音回复
+ - [ ] 高并发支持
diff --git a/channel/wechatmp/active_reply.py b/channel/wechatmp/active_reply.py
new file mode 100644
index 0000000..f236981
--- /dev/null
+++ b/channel/wechatmp/active_reply.py
@@ -0,0 +1,75 @@
+import time
+
+import web
+from wechatpy import parse_message
+from wechatpy.replies import create_reply
+
+from bridge.context import *
+from bridge.reply import *
+from channel.wechatmp.common import *
+from channel.wechatmp.wechatmp_channel import WechatMPChannel
+from channel.wechatmp.wechatmp_message import WeChatMPMessage
+from common.log import logger
+from config import conf, subscribe_msg
+
+
+# This class is instantiated once per query
+class Query:
+ def GET(self):
+ return verify_server(web.input())
+
+ def POST(self):
+ # Make sure to return the instance that first created, @singleton will do that.
+ try:
+ args = web.input()
+ verify_server(args)
+ channel = WechatMPChannel()
+ message = web.data()
+ encrypt_func = lambda x: x
+ if args.get("encrypt_type") == "aes":
+ logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
+ if not channel.crypto:
+ raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
+ message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
+ encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
+ else:
+ logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
+ msg = parse_message(message)
+ if msg.type in ["text", "voice", "image"]:
+ wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
+ from_user = wechatmp_msg.from_user_id
+ content = wechatmp_msg.content
+ message_id = wechatmp_msg.msg_id
+
+ logger.info(
+ "[wechatmp] {}:{} Receive post query {} {}: {}".format(
+ web.ctx.env.get("REMOTE_ADDR"),
+ web.ctx.env.get("REMOTE_PORT"),
+ from_user,
+ message_id,
+ content,
+ )
+ )
+ if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
+ else:
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
+ if context:
+ channel.produce(context)
+ # The reply will be sent by channel.send() in another thread
+ return "success"
+ elif msg.type == "event":
+ logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
+ if msg.event in ["subscribe", "subscribe_scan"]:
+ reply_text = subscribe_msg()
+ if reply_text:
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
+ else:
+ return "success"
+ else:
+ logger.info("暂且不处理")
+ return "success"
+ except Exception as exc:
+ logger.exception(exc)
+ return exc
diff --git a/channel/wechatmp/common.py b/channel/wechatmp/common.py
new file mode 100644
index 0000000..e1cbe7b
--- /dev/null
+++ b/channel/wechatmp/common.py
@@ -0,0 +1,27 @@
+import web
+from wechatpy.crypto import WeChatCrypto
+from wechatpy.exceptions import InvalidSignatureException
+from wechatpy.utils import check_signature
+
+from config import conf
+
+MAX_UTF8_LEN = 2048
+
+
+class WeChatAPIException(Exception):
+ pass
+
+
+def verify_server(data):
+ try:
+ signature = data.signature
+ timestamp = data.timestamp
+ nonce = data.nonce
+ echostr = data.get("echostr", None)
+ token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写
+ check_signature(token, signature, timestamp, nonce)
+ return echostr
+ except InvalidSignatureException:
+ raise web.Forbidden("Invalid signature")
+ except Exception as e:
+ raise web.Forbidden(str(e))
diff --git a/channel/wechatmp/passive_reply.py b/channel/wechatmp/passive_reply.py
new file mode 100644
index 0000000..d03efc4
--- /dev/null
+++ b/channel/wechatmp/passive_reply.py
@@ -0,0 +1,211 @@
+import asyncio
+import time
+
+import web
+from wechatpy import parse_message
+from wechatpy.replies import ImageReply, VoiceReply, create_reply
+import textwrap
+from bridge.context import *
+from bridge.reply import *
+from channel.wechatmp.common import *
+from channel.wechatmp.wechatmp_channel import WechatMPChannel
+from channel.wechatmp.wechatmp_message import WeChatMPMessage
+from common.log import logger
+from common.utils import split_string_by_utf8_length
+from config import conf, subscribe_msg
+
+
+# This class is instantiated once per query
+class Query:
+ def GET(self):
+ return verify_server(web.input())
+
+ def POST(self):
+ try:
+ args = web.input()
+ verify_server(args)
+ request_time = time.time()
+ channel = WechatMPChannel()
+ message = web.data()
+ encrypt_func = lambda x: x
+ if args.get("encrypt_type") == "aes":
+ logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
+ if not channel.crypto:
+ raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
+ message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
+ encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
+ else:
+ logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
+ msg = parse_message(message)
+ if msg.type in ["text", "voice", "image"]:
+ wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
+ from_user = wechatmp_msg.from_user_id
+ content = wechatmp_msg.content
+ message_id = wechatmp_msg.msg_id
+
+ supported = True
+ if "【收到不支持的消息类型,暂无法显示】" in content:
+ supported = False # not supported, used to refresh
+
+ # New request
+ if (
+ channel.cache_dict.get(from_user) is None
+ and from_user not in channel.running
+ or content.startswith("#")
+ and message_id not in channel.request_cnt # insert the godcmd
+ ):
+ # The first query begin
+ if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
+ else:
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
+ logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
+
+ if supported and context:
+ channel.running.add(from_user)
+ channel.produce(context)
+ else:
+ trigger_prefix = conf().get("single_chat_prefix", [""])[0]
+ if trigger_prefix or not supported:
+ if trigger_prefix:
+ reply_text = textwrap.dedent(
+ f"""\
+ 请输入'{trigger_prefix}'接你想说的话跟我说话。
+ 例如:
+ {trigger_prefix}你好,很高兴见到你。"""
+ )
+ else:
+ reply_text = textwrap.dedent(
+ """\
+ 你好,很高兴见到你。
+ 请跟我说话吧。"""
+ )
+ else:
+ logger.error(f"[wechatmp] unknown error")
+ reply_text = textwrap.dedent(
+ """\
+ 未知错误,请稍后再试"""
+ )
+
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
+
+ # Wechat official server will request 3 times (5 seconds each), with the same message_id.
+ # Because the interval is 5 seconds, here assumed that do not have multithreading problems.
+ request_cnt = channel.request_cnt.get(message_id, 0) + 1
+ channel.request_cnt[message_id] = request_cnt
+ logger.info(
+ "[wechatmp] Request {} from {} {} {}:{}\n{}".format(
+ request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
+ )
+ )
+
+ task_running = True
+ waiting_until = request_time + 4
+ while time.time() < waiting_until:
+ if from_user in channel.running:
+ time.sleep(0.1)
+ else:
+ task_running = False
+ break
+
+ reply_text = ""
+ if task_running:
+ if request_cnt < 3:
+ # waiting for timeout (the POST request will be closed by Wechat official server)
+ time.sleep(2)
+ # and do nothing, waiting for the next request
+ return "success"
+ else: # request_cnt == 3:
+ # return timeout message
+ reply_text = "【正在思考中,回复任意文字尝试获取回复】"
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
+
+ # reply is ready
+ channel.request_cnt.pop(message_id)
+
+ # no return because of bandwords or other reasons
+ if from_user not in channel.cache_dict and from_user not in channel.running:
+ return "success"
+
+ # Only one request can access to the cached data
+ try:
+ (reply_type, reply_content) = channel.cache_dict[from_user].pop(0)
+ if not channel.cache_dict[from_user]: # If popping the message makes the list empty, delete the user entry from cache
+ del channel.cache_dict[from_user]
+ except IndexError:
+ return "success"
+
+ if reply_type == "text":
+ if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
+ reply_text = reply_content
+ else:
+ continue_text = "\n【未完待续,回复任意文字以继续】"
+ splits = split_string_by_utf8_length(
+ reply_content,
+ MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
+ max_split=1,
+ )
+ reply_text = splits[0] + continue_text
+ channel.cache_dict[from_user].append(("text", splits[1]))
+
+ logger.info(
+ "[wechatmp] Request {} do send to {} {}: {}\n{}".format(
+ request_cnt,
+ from_user,
+ message_id,
+ content,
+ reply_text,
+ )
+ )
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
+
+ elif reply_type == "voice":
+ media_id = reply_content
+ asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
+ logger.info(
+ "[wechatmp] Request {} do send to {} {}: {} voice media_id {}".format(
+ request_cnt,
+ from_user,
+ message_id,
+ content,
+ media_id,
+ )
+ )
+ replyPost = VoiceReply(message=msg)
+ replyPost.media_id = media_id
+ return encrypt_func(replyPost.render())
+
+ elif reply_type == "image":
+ media_id = reply_content
+ asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
+ logger.info(
+ "[wechatmp] Request {} do send to {} {}: {} image media_id {}".format(
+ request_cnt,
+ from_user,
+ message_id,
+ content,
+ media_id,
+ )
+ )
+ replyPost = ImageReply(message=msg)
+ replyPost.media_id = media_id
+ return encrypt_func(replyPost.render())
+
+ elif msg.type == "event":
+ logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
+ if msg.event in ["subscribe", "subscribe_scan"]:
+ reply_text = subscribe_msg()
+ if reply_text:
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
+ else:
+ return "success"
+ else:
+ logger.info("暂且不处理")
+ return "success"
+ except Exception as exc:
+ logger.exception(exc)
+ return exc
diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py
new file mode 100644
index 0000000..2ae8822
--- /dev/null
+++ b/channel/wechatmp/wechatmp_channel.py
@@ -0,0 +1,236 @@
+# -*- coding: utf-8 -*-
+import asyncio
+import imghdr
+import io
+import os
+import threading
+import time
+
+import requests
+import web
+from wechatpy.crypto import WeChatCrypto
+from wechatpy.exceptions import WeChatClientException
+from collections import defaultdict
+
+from bridge.context import *
+from bridge.reply import *
+from channel.chat_channel import ChatChannel
+from channel.wechatmp.common import *
+from channel.wechatmp.wechatmp_client import WechatMPClient
+from common.log import logger
+from common.singleton import singleton
+from common.utils import split_string_by_utf8_length
+from config import conf
+from voice.audio_convert import any_to_mp3, split_audio
+
+# If using SSL, uncomment the following lines, and modify the certificate path.
+# from cheroot.server import HTTPServer
+# from cheroot.ssl.builtin import BuiltinSSLAdapter
+# HTTPServer.ssl_adapter = BuiltinSSLAdapter(
+# certificate='/ssl/cert.pem',
+# private_key='/ssl/cert.key')
+
+
+@singleton
+class WechatMPChannel(ChatChannel):
+ def __init__(self, passive_reply=True):
+ super().__init__()
+ self.passive_reply = passive_reply
+ self.NOT_SUPPORT_REPLYTYPE = []
+ appid = conf().get("wechatmp_app_id")
+ secret = conf().get("wechatmp_app_secret")
+ token = conf().get("wechatmp_token")
+ aes_key = conf().get("wechatmp_aes_key")
+ self.client = WechatMPClient(appid, secret)
+ self.crypto = None
+ if aes_key:
+ self.crypto = WeChatCrypto(token, aes_key, appid)
+ if self.passive_reply:
+ # Cache the reply to the user's first message
+ self.cache_dict = defaultdict(list)
+ # Record whether the current message is being processed
+ self.running = set()
+ # Count the request from wechat official server by message_id
+ self.request_cnt = dict()
+ # The permanent media need to be deleted to avoid media number limit
+ self.delete_media_loop = asyncio.new_event_loop()
+ t = threading.Thread(target=self.start_loop, args=(self.delete_media_loop,))
+ t.setDaemon(True)
+ t.start()
+
+ def startup(self):
+ if self.passive_reply:
+ urls = ("/wx", "channel.wechatmp.passive_reply.Query")
+ else:
+ urls = ("/wx", "channel.wechatmp.active_reply.Query")
+ app = web.application(urls, globals(), autoreload=False)
+ port = conf().get("wechatmp_port", 8080)
+ web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
+
+ def start_loop(self, loop):
+ asyncio.set_event_loop(loop)
+ loop.run_forever()
+
+ async def delete_media(self, media_id):
+ logger.debug("[wechatmp] permanent media {} will be deleted in 10s".format(media_id))
+ await asyncio.sleep(10)
+ self.client.material.delete(media_id)
+ logger.info("[wechatmp] permanent media {} has been deleted".format(media_id))
+
+ def send(self, reply: Reply, context: Context):
+ receiver = context["receiver"]
+ if self.passive_reply:
+ if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
+ reply_text = reply.content
+ logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
+ self.cache_dict[receiver].append(("text", reply_text))
+ elif reply.type == ReplyType.VOICE:
+ voice_file_path = reply.content
+ duration, files = split_audio(voice_file_path, 60 * 1000)
+ if len(files) > 1:
+ logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
+
+ for path in files:
+ # support: <2M, <60s, mp3/wma/wav/amr
+ try:
+ with open(path, "rb") as f:
+ response = self.client.material.add("voice", f)
+ logger.debug("[wechatmp] upload voice response: {}".format(response))
+ f_size = os.fstat(f.fileno()).st_size
+ time.sleep(1.0 + 2 * f_size / 1024 / 1024)
+ # todo check media_id
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload voice failed: {}".format(e))
+ return
+ media_id = response["media_id"]
+ logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
+ self.cache_dict[receiver].append(("voice", media_id))
+
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+ img_url = reply.content
+ pic_res = requests.get(img_url, stream=True)
+ image_storage = io.BytesIO()
+ for block in pic_res.iter_content(1024):
+ image_storage.write(block)
+ image_storage.seek(0)
+ image_type = imghdr.what(image_storage)
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
+ content_type = "image/" + image_type
+ try:
+ response = self.client.material.add("image", (filename, image_storage, content_type))
+ logger.debug("[wechatmp] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload image failed: {}".format(e))
+ return
+ media_id = response["media_id"]
+ logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
+ self.cache_dict[receiver].append(("image", media_id))
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+ image_storage = reply.content
+ image_storage.seek(0)
+ image_type = imghdr.what(image_storage)
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
+ content_type = "image/" + image_type
+ try:
+ response = self.client.material.add("image", (filename, image_storage, content_type))
+ logger.debug("[wechatmp] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload image failed: {}".format(e))
+ return
+ media_id = response["media_id"]
+ logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
+ self.cache_dict[receiver].append(("image", media_id))
+ else:
+ if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
+ reply_text = reply.content
+ texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
+ if len(texts) > 1:
+ logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
+ for i, text in enumerate(texts):
+ self.client.message.send_text(receiver, text)
+ if i != len(texts) - 1:
+ time.sleep(0.5) # 休眠0.5秒,防止发送过快乱序
+ logger.info("[wechatmp] Do send text to {}: {}".format(receiver, reply_text))
+ elif reply.type == ReplyType.VOICE:
+ try:
+ file_path = reply.content
+ file_name = os.path.basename(file_path)
+ file_type = os.path.splitext(file_name)[1]
+ if file_type == ".mp3":
+ file_type = "audio/mpeg"
+ elif file_type == ".amr":
+ file_type = "audio/amr"
+ else:
+ mp3_file = os.path.splitext(file_path)[0] + ".mp3"
+ any_to_mp3(file_path, mp3_file)
+ file_path = mp3_file
+ file_name = os.path.basename(file_path)
+ file_type = "audio/mpeg"
+ logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type))
+ media_ids = []
+ duration, files = split_audio(file_path, 60 * 1000)
+ if len(files) > 1:
+ logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
+ for path in files:
+ # support: <2M, <60s, AMR\MP3
+ response = self.client.media.upload("voice", (os.path.basename(path), open(path, "rb"), file_type))
+ logger.debug("[wechatcom] upload voice response: {}".format(response))
+ media_ids.append(response["media_id"])
+ os.remove(path)
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload voice failed: {}".format(e))
+ return
+
+ try:
+ os.remove(file_path)
+ except Exception:
+ pass
+
+ for media_id in media_ids:
+ self.client.message.send_voice(receiver, media_id)
+ time.sleep(1)
+ logger.info("[wechatmp] Do send voice to {}".format(receiver))
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+ img_url = reply.content
+ pic_res = requests.get(img_url, stream=True)
+ image_storage = io.BytesIO()
+ for block in pic_res.iter_content(1024):
+ image_storage.write(block)
+ image_storage.seek(0)
+ image_type = imghdr.what(image_storage)
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
+ content_type = "image/" + image_type
+ try:
+ response = self.client.media.upload("image", (filename, image_storage, content_type))
+ logger.debug("[wechatmp] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload image failed: {}".format(e))
+ return
+ self.client.message.send_image(receiver, response["media_id"])
+ logger.info("[wechatmp] Do send image to {}".format(receiver))
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+ image_storage = reply.content
+ image_storage.seek(0)
+ image_type = imghdr.what(image_storage)
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
+ content_type = "image/" + image_type
+ try:
+ response = self.client.media.upload("image", (filename, image_storage, content_type))
+ logger.debug("[wechatmp] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload image failed: {}".format(e))
+ return
+ self.client.message.send_image(receiver, response["media_id"])
+ logger.info("[wechatmp] Do send image to {}".format(receiver))
+ return
+
+ def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
+ logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
+ if self.passive_reply:
+ self.running.remove(session_id)
+
+ def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
+ logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
+ if self.passive_reply:
+ assert session_id not in self.cache_dict
+ self.running.remove(session_id)
diff --git a/channel/wechatmp/wechatmp_client.py b/channel/wechatmp/wechatmp_client.py
new file mode 100644
index 0000000..19dca32
--- /dev/null
+++ b/channel/wechatmp/wechatmp_client.py
@@ -0,0 +1,49 @@
+import threading
+import time
+
+from wechatpy.client import WeChatClient
+from wechatpy.exceptions import APILimitedException
+
+from channel.wechatmp.common import *
+from common.log import logger
+
+
+class WechatMPClient(WeChatClient):
+ def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
+ super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
+ self.fetch_access_token_lock = threading.Lock()
+ self.clear_quota_lock = threading.Lock()
+ self.last_clear_quota_time = -1
+
+ def clear_quota(self):
+ return self.post("clear_quota", data={"appid": self.appid})
+
+ def clear_quota_v2(self):
+ return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})
+
+ def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
+ with self.fetch_access_token_lock:
+ access_token = self.session.get(self.access_token_key)
+ if access_token:
+ if not self.expires_at:
+ return access_token
+ timestamp = time.time()
+ if self.expires_at - timestamp > 60:
+ return access_token
+ return super().fetch_access_token()
+
+ def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
+ try:
+ return super()._request(method, url_or_endpoint, **kwargs)
+ except APILimitedException as e:
+ logger.error("[wechatmp] API quata has been used up. {}".format(e))
+ if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
+ with self.clear_quota_lock:
+ if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
+ self.last_clear_quota_time = time.time()
+ response = self.clear_quota_v2()
+ logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
+ return super()._request(method, url_or_endpoint, **kwargs)
+ else:
+ logger.error("[wechatmp] last clear quota time is {}, less than 60s, skip clear quota")
+ raise e
diff --git a/channel/wechatmp/wechatmp_message.py b/channel/wechatmp/wechatmp_message.py
new file mode 100644
index 0000000..27c9fbb
--- /dev/null
+++ b/channel/wechatmp/wechatmp_message.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-#
+
+from bridge.context import ContextType
+from channel.chat_message import ChatMessage
+from common.log import logger
+from common.tmp_dir import TmpDir
+
+
+class WeChatMPMessage(ChatMessage):
+ def __init__(self, msg, client=None):
+ super().__init__(msg)
+ self.msg_id = msg.id
+ self.create_time = msg.time
+ self.is_group = False
+
+ if msg.type == "text":
+ self.ctype = ContextType.TEXT
+ self.content = msg.content
+ elif msg.type == "voice":
+ if msg.recognition == None:
+ self.ctype = ContextType.VOICE
+ self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
+
+ def download_voice():
+ # 如果响应状态码是200,则将响应内容写入本地文件
+ response = client.media.download(msg.media_id)
+ if response.status_code == 200:
+ with open(self.content, "wb") as f:
+ f.write(response.content)
+ else:
+ logger.info(f"[wechatmp] Failed to download voice file, {response.content}")
+
+ self._prepare_fn = download_voice
+ else:
+ self.ctype = ContextType.TEXT
+ self.content = msg.recognition
+ elif msg.type == "image":
+ self.ctype = ContextType.IMAGE
+ self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
+
+ def download_image():
+ # 如果响应状态码是200,则将响应内容写入本地文件
+ response = client.media.download(msg.media_id)
+ if response.status_code == 200:
+ with open(self.content, "wb") as f:
+ f.write(response.content)
+ else:
+ logger.info(f"[wechatmp] Failed to download image file, {response.content}")
+
+ self._prepare_fn = download_image
+ else:
+ raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
+
+ self.from_user_id = msg.source
+ self.to_user_id = msg.target
+ self.other_user_id = msg.source
diff --git a/channel/wework/run.py b/channel/wework/run.py
new file mode 100644
index 0000000..1e7d5b3
--- /dev/null
+++ b/channel/wework/run.py
@@ -0,0 +1,17 @@
+import os
+import time
+os.environ['ntwork_LOG'] = "ERROR"
+import ntwork
+
+wework = ntwork.WeWork()
+
+
+def forever():
+ try:
+ while True:
+ time.sleep(0.1)
+ except KeyboardInterrupt:
+ ntwork.exit_()
+ os._exit(0)
+
+
diff --git a/channel/wework/wework_channel.py b/channel/wework/wework_channel.py
new file mode 100644
index 0000000..1020261
--- /dev/null
+++ b/channel/wework/wework_channel.py
@@ -0,0 +1,326 @@
+import io
+import os
+import random
+import tempfile
+import threading
+os.environ['ntwork_LOG'] = "ERROR"
+import ntwork
+import requests
+import uuid
+
+from bridge.context import *
+from bridge.reply import *
+from channel.chat_channel import ChatChannel
+from channel.wework.wework_message import *
+from channel.wework.wework_message import WeworkMessage
+from common.singleton import singleton
+from common.log import logger
+from common.time_check import time_checker
+from common.utils import compress_imgfile, fsize
+from config import conf
+from channel.wework.run import wework
+from channel.wework import run
+from PIL import Image
+
+
+def get_wxid_by_name(room_members, group_wxid, name):
+ if group_wxid in room_members:
+ for member in room_members[group_wxid]['member_list']:
+ if member['room_nickname'] == name or member['username'] == name:
+ return member['user_id']
+ return None # 如果没有找到对应的group_wxid或name,则返回None
+
+
+def download_and_compress_image(url, filename, quality=30):
+ # 确定保存图片的目录
+ directory = os.path.join(os.getcwd(), "tmp")
+ # 如果目录不存在,则创建目录
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ # 下载图片
+ pic_res = requests.get(url, stream=True)
+ image_storage = io.BytesIO()
+ for block in pic_res.iter_content(1024):
+ image_storage.write(block)
+
+ # 检查图片大小并可能进行压缩
+ sz = fsize(image_storage)
+ if sz >= 10 * 1024 * 1024: # 如果图片大于 10 MB
+ logger.info("[wework] image too large, ready to compress, sz={}".format(sz))
+ image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
+ logger.info("[wework] image compressed, sz={}".format(fsize(image_storage)))
+
+ # 将内存缓冲区的指针重置到起始位置
+ image_storage.seek(0)
+
+ # 读取并保存图片
+ image = Image.open(image_storage)
+ image_path = os.path.join(directory, f"{filename}.png")
+ image.save(image_path, "png")
+
+ return image_path
+
+
+def download_video(url, filename):
+ # 确定保存视频的目录
+ directory = os.path.join(os.getcwd(), "tmp")
+ # 如果目录不存在,则创建目录
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ # 下载视频
+ response = requests.get(url, stream=True)
+ total_size = 0
+
+ video_path = os.path.join(directory, f"{filename}.mp4")
+
+ with open(video_path, 'wb') as f:
+ for block in response.iter_content(1024):
+ total_size += len(block)
+
+ # 如果视频的总大小超过30MB (30 * 1024 * 1024 bytes),则停止下载并返回
+ if total_size > 30 * 1024 * 1024:
+ logger.info("[WX] Video is larger than 30MB, skipping...")
+ return None
+
+ f.write(block)
+
+ return video_path
+
+
+def create_message(wework_instance, message, is_group):
+ logger.debug(f"正在为{'群聊' if is_group else '单聊'}创建 WeworkMessage")
+ cmsg = WeworkMessage(message, wework=wework_instance, is_group=is_group)
+ logger.debug(f"cmsg:{cmsg}")
+ return cmsg
+
+
+def handle_message(cmsg, is_group):
+ logger.debug(f"准备用 WeworkChannel 处理{'群聊' if is_group else '单聊'}消息")
+ if is_group:
+ WeworkChannel().handle_group(cmsg)
+ else:
+ WeworkChannel().handle_single(cmsg)
+ logger.debug(f"已用 WeworkChannel 处理完{'群聊' if is_group else '单聊'}消息")
+
+
+def _check(func):
+ def wrapper(self, cmsg: ChatMessage):
+ msgId = cmsg.msg_id
+ create_time = cmsg.create_time # 消息时间戳
+ if create_time is None:
+ return func(self, cmsg)
+ if int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
+ logger.debug("[WX]history message {} skipped".format(msgId))
+ return
+ return func(self, cmsg)
+
+ return wrapper
+
+
+@wework.msg_register(
+ [ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_LINK_CARD_MSG,ntwork.MT_RECV_FILE_MSG, ntwork.MT_RECV_VOICE_MSG])
+def all_msg_handler(wework_instance: ntwork.WeWork, message):
+ logger.debug(f"收到消息: {message}")
+ if 'data' in message:
+ # 首先查找conversation_id,如果没有找到,则查找room_conversation_id
+ conversation_id = message['data'].get('conversation_id', message['data'].get('room_conversation_id'))
+ if conversation_id is not None:
+ is_group = "R:" in conversation_id
+ try:
+ cmsg = create_message(wework_instance=wework_instance, message=message, is_group=is_group)
+ except NotImplementedError as e:
+ logger.error(f"[WX]{message.get('MsgId', 'unknown')} 跳过: {e}")
+ return None
+ delay = random.randint(1, 2)
+ timer = threading.Timer(delay, handle_message, args=(cmsg, is_group))
+ timer.start()
+ else:
+ logger.debug("消息数据中无 conversation_id")
+ return None
+ return None
+
+
+def accept_friend_with_retries(wework_instance, user_id, corp_id):
+ result = wework_instance.accept_friend(user_id, corp_id)
+ logger.debug(f'result:{result}')
+
+
+# @wework.msg_register(ntwork.MT_RECV_FRIEND_MSG)
+# def friend(wework_instance: ntwork.WeWork, message):
+# data = message["data"]
+# user_id = data["user_id"]
+# corp_id = data["corp_id"]
+# logger.info(f"接收到好友请求,消息内容:{data}")
+# delay = random.randint(1, 180)
+# threading.Timer(delay, accept_friend_with_retries, args=(wework_instance, user_id, corp_id)).start()
+#
+# return None
+
+
+def get_with_retry(get_func, max_retries=5, delay=5):
+ retries = 0
+ result = None
+ while retries < max_retries:
+ result = get_func()
+ if result:
+ break
+ logger.warning(f"获取数据失败,重试第{retries + 1}次······")
+ retries += 1
+ time.sleep(delay) # 等待一段时间后重试
+ return result
+
+
+@singleton
+class WeworkChannel(ChatChannel):
+ NOT_SUPPORT_REPLYTYPE = []
+
+ def __init__(self):
+ super().__init__()
+
+ def startup(self):
+ smart = conf().get("wework_smart", True)
+ wework.open(smart)
+ logger.info("等待登录······")
+ wework.wait_login()
+ login_info = wework.get_login_info()
+ self.user_id = login_info['user_id']
+ self.name = login_info['nickname']
+ logger.info(f"登录信息:>>>user_id:{self.user_id}>>>>>>>>name:{self.name}")
+ logger.info("静默延迟60s,等待客户端刷新数据,请勿进行任何操作······")
+ time.sleep(60)
+ contacts = get_with_retry(wework.get_external_contacts)
+ rooms = get_with_retry(wework.get_rooms)
+ directory = os.path.join(os.getcwd(), "tmp")
+ if not contacts or not rooms:
+ logger.error("获取contacts或rooms失败,程序退出")
+ ntwork.exit_()
+ os.exit(0)
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ # 将contacts保存到json文件中
+ with open(os.path.join(directory, 'wework_contacts.json'), 'w', encoding='utf-8') as f:
+ json.dump(contacts, f, ensure_ascii=False, indent=4)
+ with open(os.path.join(directory, 'wework_rooms.json'), 'w', encoding='utf-8') as f:
+ json.dump(rooms, f, ensure_ascii=False, indent=4)
+ # 创建一个空字典来保存结果
+ result = {}
+
+ # 遍历列表中的每个字典
+ for room in rooms['room_list']:
+ # 获取聊天室ID
+ room_wxid = room['conversation_id']
+
+ # 获取聊天室成员
+ room_members = wework.get_room_members(room_wxid)
+
+ # 将聊天室成员保存到结果字典中
+ result[room_wxid] = room_members
+
+ # 将结果保存到json文件中
+ with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
+ json.dump(result, f, ensure_ascii=False, indent=4)
+ logger.info("wework程序初始化完成········")
+ run.forever()
+
+ @time_checker
+ @_check
+ def handle_single(self, cmsg: ChatMessage):
+ if cmsg.from_user_id == cmsg.to_user_id:
+ # ignore self reply
+ return
+ if cmsg.ctype == ContextType.VOICE:
+ if not conf().get("speech_recognition"):
+ return
+ logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.IMAGE:
+ logger.debug("[WX]receive image msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.PATPAT:
+ logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.TEXT:
+ logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
+ else:
+ logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
+ if context:
+ self.produce(context)
+
+ @time_checker
+ @_check
+ def handle_group(self, cmsg: ChatMessage):
+ if cmsg.ctype == ContextType.VOICE:
+ if not conf().get("speech_recognition"):
+ return
+ logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.IMAGE:
+ logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
+ elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
+ logger.debug("[WX]receive note msg: {}".format(cmsg.content))
+ elif cmsg.ctype == ContextType.TEXT:
+ pass
+ else:
+ logger.debug("[WX]receive group msg: {}".format(cmsg.content))
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
+ if context:
+ self.produce(context)
+
+ # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
+ def send(self, reply: Reply, context: Context):
+ logger.debug(f"context: {context}")
+ receiver = context["receiver"]
+ actual_user_id = context["msg"].actual_user_id
+ if reply.type == ReplyType.TEXT or reply.type == ReplyType.TEXT_:
+ match = re.search(r"^@(.*?)\n", reply.content)
+ logger.debug(f"match: {match}")
+ if match:
+ new_content = re.sub(r"^@(.*?)\n", "\n", reply.content)
+ at_list = [actual_user_id]
+ logger.debug(f"new_content: {new_content}")
+ wework.send_room_at_msg(receiver, new_content, at_list)
+ else:
+ wework.send_text(receiver, reply.content)
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
+ elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
+ wework.send_text(receiver, reply.content)
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
+ image_storage = reply.content
+ image_storage.seek(0)
+ # Read data from image_storage
+ data = image_storage.read()
+ # Create a temporary file
+ with tempfile.NamedTemporaryFile(delete=False) as temp:
+ temp_path = temp.name
+ temp.write(data)
+ # Send the image
+ wework.send_image(receiver, temp_path)
+ logger.info("[WX] sendImage, receiver={}".format(receiver))
+ # Remove the temporary file
+ os.remove(temp_path)
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
+ img_url = reply.content
+ filename = str(uuid.uuid4())
+
+ # 调用你的函数,下载图片并保存为本地文件
+ image_path = download_and_compress_image(img_url, filename)
+
+ wework.send_image(receiver, file_path=image_path)
+ logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
+ elif reply.type == ReplyType.VIDEO_URL:
+ video_url = reply.content
+ filename = str(uuid.uuid4())
+ video_path = download_video(video_url, filename)
+
+ if video_path is None:
+ # 如果视频太大,下载可能会被跳过,此时 video_path 将为 None
+ wework.send_text(receiver, "抱歉,视频太大了!!!")
+ else:
+ wework.send_video(receiver, video_path)
+ logger.info("[WX] sendVideo, receiver={}".format(receiver))
+ elif reply.type == ReplyType.VOICE:
+ current_dir = os.getcwd()
+ voice_file = reply.content.split("/")[-1]
+ reply.content = os.path.join(current_dir, "tmp", voice_file)
+ wework.send_file(receiver, reply.content)
+ logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
diff --git a/channel/wework/wework_message.py b/channel/wework/wework_message.py
new file mode 100644
index 0000000..0d9e96e
--- /dev/null
+++ b/channel/wework/wework_message.py
@@ -0,0 +1,227 @@
+import datetime
+import json
+import os
+import re
+import time
+import pilk
+
+from bridge.context import ContextType
+from channel.chat_message import ChatMessage
+from common.log import logger
+from ntwork.const import send_type
+
+
+def get_with_retry(get_func, max_retries=5, delay=5):
+ retries = 0
+ result = None
+ while retries < max_retries:
+ result = get_func()
+ if result:
+ break
+ logger.warning(f"获取数据失败,重试第{retries + 1}次······")
+ retries += 1
+ time.sleep(delay) # 等待一段时间后重试
+ return result
+
+
+def get_room_info(wework, conversation_id):
+ logger.debug(f"传入的 conversation_id: {conversation_id}")
+ rooms = wework.get_rooms()
+ if not rooms or 'room_list' not in rooms:
+ logger.error(f"获取群聊信息失败: {rooms}")
+ return None
+ time.sleep(1)
+ logger.debug(f"获取到的群聊信息: {rooms}")
+ for room in rooms['room_list']:
+ if room['conversation_id'] == conversation_id:
+ return room
+ return None
+
+
+def cdn_download(wework, message, file_name):
+ data = message["data"]
+ aes_key = data["cdn"]["aes_key"]
+ file_size = data["cdn"]["size"]
+
+ # 获取当前工作目录,然后与文件名拼接得到保存路径
+ current_dir = os.getcwd()
+ save_path = os.path.join(current_dir, "tmp", file_name)
+
+ # 下载保存图片到本地
+ if "url" in data["cdn"].keys() and "auth_key" in data["cdn"].keys():
+ url = data["cdn"]["url"]
+ auth_key = data["cdn"]["auth_key"]
+ # result = wework.wx_cdn_download(url, auth_key, aes_key, file_size, save_path) # ntwork库本身接口有问题,缺失了aes_key这个参数
+ """
+ 下载wx类型的cdn文件,以https开头
+ """
+ data = {
+ 'url': url,
+ 'auth_key': auth_key,
+ 'aes_key': aes_key,
+ 'size': file_size,
+ 'save_path': save_path
+ }
+ result = wework._WeWork__send_sync(send_type.MT_WXCDN_DOWNLOAD_MSG, data) # 直接用wx_cdn_download的接口内部实现来调用
+ elif "file_id" in data["cdn"].keys():
+ if message["type"] == 11042:
+ file_type = 2
+ elif message["type"] == 11045:
+ file_type = 5
+ file_id = data["cdn"]["file_id"]
+ result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
+ else:
+ logger.error(f"something is wrong, data: {data}")
+ return
+
+ # 输出下载结果
+ logger.debug(f"result: {result}")
+
+
+def c2c_download_and_convert(wework, message, file_name):
+ data = message["data"]
+ aes_key = data["cdn"]["aes_key"]
+ file_size = data["cdn"]["size"]
+ file_type = 5
+ file_id = data["cdn"]["file_id"]
+
+ current_dir = os.getcwd()
+ save_path = os.path.join(current_dir, "tmp", file_name)
+ result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
+ logger.debug(result)
+
+ # 在下载完SILK文件之后,立即将其转换为WAV文件
+ base_name, _ = os.path.splitext(save_path)
+ wav_file = base_name + ".wav"
+ pilk.silk_to_wav(save_path, wav_file, rate=24000)
+
+ # 删除SILK文件
+ try:
+ os.remove(save_path)
+ except Exception as e:
+ pass
+
+
+class WeworkMessage(ChatMessage):
+ def __init__(self, wework_msg, wework, is_group=False):
+ try:
+ super().__init__(wework_msg)
+ self.msg_id = wework_msg['data'].get('conversation_id', wework_msg['data'].get('room_conversation_id'))
+ # 使用.get()防止 'send_time' 键不存在时抛出错误
+ self.create_time = wework_msg['data'].get("send_time")
+ self.is_group = is_group
+ self.wework = wework
+
+ if wework_msg["type"] == 11041: # 文本消息类型
+ if any(substring in wework_msg['data']['content'] for substring in ("该消息类型暂不能展示", "不支持的消息类型")):
+ return
+ self.ctype = ContextType.TEXT
+ self.content = wework_msg['data']['content']
+ elif wework_msg["type"] == 11044: # 语音消息类型,需要缓存文件
+ file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".silk"
+ base_name, _ = os.path.splitext(file_name)
+ file_name_2 = base_name + ".wav"
+ current_dir = os.getcwd()
+ self.ctype = ContextType.VOICE
+ self.content = os.path.join(current_dir, "tmp", file_name_2)
+ self._prepare_fn = lambda: c2c_download_and_convert(wework, wework_msg, file_name)
+ elif wework_msg["type"] == 11042: # 图片消息类型,需要下载文件
+ file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".jpg"
+ current_dir = os.getcwd()
+ self.ctype = ContextType.IMAGE
+ self.content = os.path.join(current_dir, "tmp", file_name)
+ self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
+ elif wework_msg["type"] == 11045: # 文件消息
+ print("文件消息")
+ print(wework_msg)
+ file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
+ file_name = file_name + wework_msg['data']['cdn']['file_name']
+ current_dir = os.getcwd()
+ self.ctype = ContextType.FILE
+ self.content = os.path.join(current_dir, "tmp", file_name)
+ self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
+ elif wework_msg["type"] == 11047: # 链接消息
+ self.ctype = ContextType.SHARING
+ self.content = wework_msg['data']['url']
+ elif wework_msg["type"] == 11072: # 新成员入群通知
+ self.ctype = ContextType.JOIN_GROUP
+ member_list = wework_msg['data']['member_list']
+ self.actual_user_nickname = member_list[0]['name']
+ self.actual_user_id = member_list[0]['user_id']
+ self.content = f"{self.actual_user_nickname}加入了群聊!"
+ directory = os.path.join(os.getcwd(), "tmp")
+ rooms = get_with_retry(wework.get_rooms)
+ if not rooms:
+ logger.error("更新群信息失败···")
+ else:
+ result = {}
+ for room in rooms['room_list']:
+ # 获取聊天室ID
+ room_wxid = room['conversation_id']
+
+ # 获取聊天室成员
+ room_members = wework.get_room_members(room_wxid)
+
+ # 将聊天室成员保存到结果字典中
+ result[room_wxid] = room_members
+ with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
+ json.dump(result, f, ensure_ascii=False, indent=4)
+ logger.info("有新成员加入,已自动更新群成员列表缓存!")
+ else:
+ raise NotImplementedError(
+ "Unsupported message type: Type:{} MsgType:{}".format(wework_msg["type"], wework_msg["MsgType"]))
+
+ data = wework_msg['data']
+ login_info = self.wework.get_login_info()
+ logger.debug(f"login_info: {login_info}")
+ nickname = f"{login_info['username']}({login_info['nickname']})" if login_info['nickname'] else login_info['username']
+ user_id = login_info['user_id']
+
+ sender_id = data.get('sender')
+ conversation_id = data.get('conversation_id')
+ sender_name = data.get("sender_name")
+
+ self.from_user_id = user_id if sender_id == user_id else conversation_id
+ self.from_user_nickname = nickname if sender_id == user_id else sender_name
+ self.to_user_id = user_id
+ self.to_user_nickname = nickname
+ self.other_user_nickname = sender_name
+ self.other_user_id = conversation_id
+
+ if self.is_group:
+ conversation_id = data.get('conversation_id') or data.get('room_conversation_id')
+ self.other_user_id = conversation_id
+ if conversation_id:
+ room_info = get_room_info(wework=wework, conversation_id=conversation_id)
+ self.other_user_nickname = room_info.get('nickname', None) if room_info else None
+ self.from_user_nickname = room_info.get('nickname', None) if room_info else None
+ at_list = data.get('at_list', [])
+ tmp_list = []
+ for at in at_list:
+ tmp_list.append(at['nickname'])
+ at_list = tmp_list
+ logger.debug(f"at_list: {at_list}")
+ logger.debug(f"nickname: {nickname}")
+ self.is_at = False
+ if nickname in at_list or login_info['nickname'] in at_list or login_info['username'] in at_list:
+ self.is_at = True
+ self.at_list = at_list
+
+ # 检查消息内容是否包含@用户名。处理复制粘贴的消息,这类消息可能不会触发@通知,但内容中可能包含 "@用户名"。
+ content = data.get('content', '')
+ name = nickname
+ pattern = f"@{re.escape(name)}(\u2005|\u0020)"
+ if re.search(pattern, content):
+ logger.debug(f"Wechaty message {self.msg_id} includes at")
+ self.is_at = True
+
+ if not self.actual_user_id:
+ self.actual_user_id = data.get("sender")
+ self.actual_user_nickname = sender_name if self.ctype != ContextType.JOIN_GROUP else self.actual_user_nickname
+ else:
+ logger.error("群聊消息中没有找到 conversation_id 或 room_conversation_id")
+
+ logger.debug(f"WeworkMessage has been successfully instantiated with message id: {self.msg_id}")
+ except Exception as e:
+ logger.error(f"在 WeworkMessage 的初始化过程中出现错误:{e}")
+ raise e
diff --git a/common/const.py b/common/const.py
new file mode 100644
index 0000000..aeb9dcc
--- /dev/null
+++ b/common/const.py
@@ -0,0 +1,28 @@
+# bot_type
+OPEN_AI = "openAI"
+CHATGPT = "chatGPT"
+BAIDU = "baidu"
+XUNFEI = "xunfei"
+CHATGPTONAZURE = "chatGPTOnAzure"
+LINKAI = "linkai"
+CLAUDEAI = "claude"
+QWEN = "qwen"
+GEMINI = "gemini"
+ZHIPU_AI = "glm-4"
+
+
+# model
+GPT35 = "gpt-3.5-turbo"
+GPT4 = "gpt-4"
+GPT4_TURBO_PREVIEW = "gpt-4-0125-preview"
+GPT4_VISION_PREVIEW = "gpt-4-vision-preview"
+WHISPER_1 = "whisper-1"
+TTS_1 = "tts-1"
+TTS_1_HD = "tts-1-hd"
+
+MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo",
+ "gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI]
+
+# channel
+FEISHU = "feishu"
+DINGTALK = "dingtalk"
diff --git a/common/dequeue.py b/common/dequeue.py
new file mode 100644
index 0000000..39baf58
--- /dev/null
+++ b/common/dequeue.py
@@ -0,0 +1,33 @@
+from queue import Full, Queue
+from time import monotonic as time
+
+
+# add implementation of putleft to Queue
+class Dequeue(Queue):
+ def putleft(self, item, block=True, timeout=None):
+ with self.not_full:
+ if self.maxsize > 0:
+ if not block:
+ if self._qsize() >= self.maxsize:
+ raise Full
+ elif timeout is None:
+ while self._qsize() >= self.maxsize:
+ self.not_full.wait()
+ elif timeout < 0:
+ raise ValueError("'timeout' must be a non-negative number")
+ else:
+ endtime = time() + timeout
+ while self._qsize() >= self.maxsize:
+ remaining = endtime - time()
+ if remaining <= 0.0:
+ raise Full
+ self.not_full.wait(remaining)
+ self._putleft(item)
+ self.unfinished_tasks += 1
+ self.not_empty.notify()
+
+ def putleft_nowait(self, item):
+ return self.putleft(item, block=False)
+
+ def _putleft(self, item):
+ self.queue.appendleft(item)
diff --git a/common/expired_dict.py b/common/expired_dict.py
new file mode 100644
index 0000000..42fb4b1
--- /dev/null
+++ b/common/expired_dict.py
@@ -0,0 +1,42 @@
+from datetime import datetime, timedelta
+
+
+class ExpiredDict(dict):
+ def __init__(self, expires_in_seconds):
+ super().__init__()
+ self.expires_in_seconds = expires_in_seconds
+
+ def __getitem__(self, key):
+ value, expiry_time = super().__getitem__(key)
+ if datetime.now() > expiry_time:
+ del self[key]
+ raise KeyError("expired {}".format(key))
+ self.__setitem__(key, value)
+ return value
+
+ def __setitem__(self, key, value):
+ expiry_time = datetime.now() + timedelta(seconds=self.expires_in_seconds)
+ super().__setitem__(key, (value, expiry_time))
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def __contains__(self, key):
+ try:
+ self[key]
+ return True
+ except KeyError:
+ return False
+
+ def keys(self):
+ keys = list(super().keys())
+ return [key for key in keys if key in self]
+
+ def items(self):
+ return [(key, self[key]) for key in self.keys()]
+
+ def __iter__(self):
+ return self.keys().__iter__()
diff --git a/common/linkai_client.py b/common/linkai_client.py
new file mode 100644
index 0000000..ad7d213
--- /dev/null
+++ b/common/linkai_client.py
@@ -0,0 +1,55 @@
+from bridge.context import Context, ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from linkai import LinkAIClient, PushMsg
+from config import conf, pconf, plugin_config
+from plugins import PluginManager
+
+
+chat_client: LinkAIClient
+
+class ChatClient(LinkAIClient):
+ def __init__(self, api_key, host, channel):
+ super().__init__(api_key, host)
+ self.channel = channel
+ self.client_type = channel.channel_type
+
+ def on_message(self, push_msg: PushMsg):
+ session_id = push_msg.session_id
+ msg_content = push_msg.msg_content
+ logger.info(f"receive msg push, session_id={session_id}, msg_content={msg_content}")
+ context = Context()
+ context.type = ContextType.TEXT
+ context["receiver"] = session_id
+ context["isgroup"] = push_msg.is_group
+ self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
+
+ def on_config(self, config: dict):
+ if not self.client_id:
+ return
+ logger.info(f"从控制台加载配置: {config}")
+ local_config = conf()
+ for key in local_config.keys():
+ if config.get(key) is not None:
+ local_config[key] = config.get(key)
+ if config.get("reply_voice_mode"):
+ if config.get("reply_voice_mode") == "voice_reply_voice":
+ local_config["voice_reply_voice"] = True
+ elif config.get("reply_voice_mode") == "always_reply_voice":
+ local_config["always_reply_voice"] = True
+ # if config.get("admin_password") and plugin_config["Godcmd"]:
+ # plugin_config["Godcmd"]["password"] = config.get("admin_password")
+ # PluginManager().instances["Godcmd"].reload()
+ # if config.get("group_app_map") and pconf("linkai"):
+ # local_group_map = {}
+ # for mapping in config.get("group_app_map"):
+ # local_group_map[mapping.get("group_name")] = mapping.get("app_code")
+ # pconf("linkai")["group_app_map"] = local_group_map
+ # PluginManager().instances["linkai"].reload()
+
+
+def start(channel):
+ global chat_client
+ chat_client = ChatClient(api_key=conf().get("linkai_api_key"),
+ host="link-ai.chat", channel=channel)
+ chat_client.start()
diff --git a/common/log.py b/common/log.py
new file mode 100644
index 0000000..f02a365
--- /dev/null
+++ b/common/log.py
@@ -0,0 +1,38 @@
+import logging
+import sys
+
+
+def _reset_logger(log):
+ for handler in log.handlers:
+ handler.close()
+ log.removeHandler(handler)
+ del handler
+ log.handlers.clear()
+ log.propagate = False
+ console_handle = logging.StreamHandler(sys.stdout)
+ console_handle.setFormatter(
+ logging.Formatter(
+ "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+ )
+ file_handle = logging.FileHandler("run.log", encoding="utf-8")
+ file_handle.setFormatter(
+ logging.Formatter(
+ "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+ )
+ log.addHandler(file_handle)
+ log.addHandler(console_handle)
+
+
+def _get_logger():
+ log = logging.getLogger("log")
+ _reset_logger(log)
+ log.setLevel(logging.INFO)
+ return log
+
+
+# 日志句柄
+logger = _get_logger()
diff --git a/common/memory.py b/common/memory.py
new file mode 100644
index 0000000..026bed2
--- /dev/null
+++ b/common/memory.py
@@ -0,0 +1,3 @@
+from common.expired_dict import ExpiredDict
+
+USER_IMAGE_CACHE = ExpiredDict(60 * 3)
\ No newline at end of file
diff --git a/common/package_manager.py b/common/package_manager.py
new file mode 100644
index 0000000..8f1aa45
--- /dev/null
+++ b/common/package_manager.py
@@ -0,0 +1,36 @@
+import time
+
+import pip
+from pip._internal import main as pipmain
+
+from common.log import _reset_logger, logger
+
+
+def install(package):
+ pipmain(["install", package])
+
+
+def install_requirements(file):
+ pipmain(["install", "-r", file, "--upgrade"])
+ _reset_logger(logger)
+
+
+def check_dulwich():
+ needwait = False
+ for i in range(2):
+ if needwait:
+ time.sleep(3)
+ needwait = False
+ try:
+ import dulwich
+
+ return
+ except ImportError:
+ try:
+ install("dulwich")
+ except:
+ needwait = True
+ try:
+ import dulwich
+ except ImportError:
+ raise ImportError("Unable to import dulwich")
diff --git a/common/singleton.py b/common/singleton.py
new file mode 100644
index 0000000..b46095c
--- /dev/null
+++ b/common/singleton.py
@@ -0,0 +1,9 @@
+def singleton(cls):
+ instances = {}
+
+ def get_instance(*args, **kwargs):
+ if cls not in instances:
+ instances[cls] = cls(*args, **kwargs)
+ return instances[cls]
+
+ return get_instance
diff --git a/common/sorted_dict.py b/common/sorted_dict.py
new file mode 100644
index 0000000..7a1e85b
--- /dev/null
+++ b/common/sorted_dict.py
@@ -0,0 +1,65 @@
+import heapq
+
+
+class SortedDict(dict):
+ def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False):
+ if init_dict is None:
+ init_dict = []
+ if isinstance(init_dict, dict):
+ init_dict = init_dict.items()
+ self.sort_func = sort_func
+ self.sorted_keys = None
+ self.reverse = reverse
+ self.heap = []
+ for k, v in init_dict:
+ self[k] = v
+
+ def __setitem__(self, key, value):
+ if key in self:
+ super().__setitem__(key, value)
+ for i, (priority, k) in enumerate(self.heap):
+ if k == key:
+ self.heap[i] = (self.sort_func(key, value), key)
+ heapq.heapify(self.heap)
+ break
+ self.sorted_keys = None
+ else:
+ super().__setitem__(key, value)
+ heapq.heappush(self.heap, (self.sort_func(key, value), key))
+ self.sorted_keys = None
+
+ def __delitem__(self, key):
+ super().__delitem__(key)
+ for i, (priority, k) in enumerate(self.heap):
+ if k == key:
+ del self.heap[i]
+ heapq.heapify(self.heap)
+ break
+ self.sorted_keys = None
+
+ def keys(self):
+ if self.sorted_keys is None:
+ self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
+ return self.sorted_keys
+
+ def items(self):
+ if self.sorted_keys is None:
+ self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)]
+ sorted_items = [(k, self[k]) for k in self.sorted_keys]
+ return sorted_items
+
+ def _update_heap(self, key):
+ for i, (priority, k) in enumerate(self.heap):
+ if k == key:
+ new_priority = self.sort_func(key, self[key])
+ if new_priority != priority:
+ self.heap[i] = (new_priority, key)
+ heapq.heapify(self.heap)
+ self.sorted_keys = None
+ break
+
+ def __iter__(self):
+ return iter(self.keys())
+
+ def __repr__(self):
+ return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})"
diff --git a/common/time_check.py b/common/time_check.py
new file mode 100644
index 0000000..5c2dacb
--- /dev/null
+++ b/common/time_check.py
@@ -0,0 +1,42 @@
+import hashlib
+import re
+import time
+
+import config
+from common.log import logger
+
+
+def time_checker(f):
+ def _time_checker(self, *args, **kwargs):
+ _config = config.conf()
+ chat_time_module = _config.get("chat_time_module", False)
+ if chat_time_module:
+ chat_start_time = _config.get("chat_start_time", "00:00")
+ chat_stopt_time = _config.get("chat_stop_time", "24:00")
+ time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00
+
+ starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
+ stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
+ chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
+
+ # 时间格式检查
+ if not (starttime_format_check and stoptime_format_check and chat_time_check):
+ logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
+ if chat_start_time > "23:59":
+ logger.error("启动时间可能存在问题,请修改!")
+
+ # 服务时间检查
+ now_time = time.strftime("%H:%M", time.localtime())
+ if chat_start_time <= now_time <= chat_stopt_time: # 服务时间内,正常返回回答
+ f(self, *args, **kwargs)
+ return None
+ else:
+ if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置
+ f(self, *args, **kwargs)
+ else:
+ logger.info("非服务时间内,不接受访问")
+ return None
+ else:
+ f(self, *args, **kwargs) # 未开启时间模块则直接回答
+
+ return _time_checker
diff --git a/common/tmp_dir.py b/common/tmp_dir.py
new file mode 100644
index 0000000..b01880b
--- /dev/null
+++ b/common/tmp_dir.py
@@ -0,0 +1,18 @@
+import os
+import pathlib
+
+from config import conf
+
+
+class TmpDir(object):
+ """A temporary directory that is deleted when the object is destroyed."""
+
+ tmpFilePath = pathlib.Path("./tmp/")
+
+ def __init__(self):
+ pathExists = os.path.exists(self.tmpFilePath)
+ if not pathExists:
+ os.makedirs(self.tmpFilePath)
+
+ def path(self):
+ return str(self.tmpFilePath) + "/"
diff --git a/common/token_bucket.py b/common/token_bucket.py
new file mode 100644
index 0000000..23901b6
--- /dev/null
+++ b/common/token_bucket.py
@@ -0,0 +1,45 @@
+import threading
+import time
+
+
+class TokenBucket:
+ def __init__(self, tpm, timeout=None):
+ self.capacity = int(tpm) # 令牌桶容量
+ self.tokens = 0 # 初始令牌数为0
+ self.rate = int(tpm) / 60 # 令牌每秒生成速率
+ self.timeout = timeout # 等待令牌超时时间
+ self.cond = threading.Condition() # 条件变量
+ self.is_running = True
+ # 开启令牌生成线程
+ threading.Thread(target=self._generate_tokens).start()
+
+ def _generate_tokens(self):
+ """生成令牌"""
+ while self.is_running:
+ with self.cond:
+ if self.tokens < self.capacity:
+ self.tokens += 1
+ self.cond.notify() # 通知获取令牌的线程
+ time.sleep(1 / self.rate)
+
+ def get_token(self):
+ """获取令牌"""
+ with self.cond:
+ while self.tokens <= 0:
+ flag = self.cond.wait(self.timeout)
+ if not flag: # 超时
+ return False
+ self.tokens -= 1
+ return True
+
+ def close(self):
+ self.is_running = False
+
+
+if __name__ == "__main__":
+ token_bucket = TokenBucket(20, None) # 创建一个每分钟生产20个tokens的令牌桶
+ # token_bucket = TokenBucket(20, 0.1)
+ for i in range(3):
+ if token_bucket.get_token():
+ print(f"第{i+1}次请求成功")
+ token_bucket.close()
diff --git a/common/utils.py b/common/utils.py
new file mode 100644
index 0000000..dd69c9d
--- /dev/null
+++ b/common/utils.py
@@ -0,0 +1,56 @@
+import io
+import os
+from urllib.parse import urlparse
+from PIL import Image
+
+
+def fsize(file):
+ if isinstance(file, io.BytesIO):
+ return file.getbuffer().nbytes
+ elif isinstance(file, str):
+ return os.path.getsize(file)
+ elif hasattr(file, "seek") and hasattr(file, "tell"):
+ pos = file.tell()
+ file.seek(0, os.SEEK_END)
+ size = file.tell()
+ file.seek(pos)
+ return size
+ else:
+ raise TypeError("Unsupported type")
+
+
+def compress_imgfile(file, max_size):
+ if fsize(file) <= max_size:
+ return file
+ file.seek(0)
+ img = Image.open(file)
+ rgb_image = img.convert("RGB")
+ quality = 95
+ while True:
+ out_buf = io.BytesIO()
+ rgb_image.save(out_buf, "JPEG", quality=quality)
+ if fsize(out_buf) <= max_size:
+ return out_buf
+ quality -= 5
+
+
+def split_string_by_utf8_length(string, max_length, max_split=0):
+ encoded = string.encode("utf-8")
+ start, end = 0, 0
+ result = []
+ while end < len(encoded):
+ if max_split > 0 and len(result) >= max_split:
+ result.append(encoded[start:].decode("utf-8"))
+ break
+ end = min(start + max_length, len(encoded))
+ # 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
+ while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
+ end -= 1
+ result.append(encoded[start:end].decode("utf-8"))
+ start = end
+ return result
+
+
+def get_path_suffix(path):
+ path = urlparse(path).path
+ return os.path.splitext(path)[-1].lstrip('.')
diff --git a/config-template.json b/config-template.json
new file mode 100644
index 0000000..bdaadde
--- /dev/null
+++ b/config-template.json
@@ -0,0 +1,36 @@
+{
+ "channel_type": "wx",
+ "model": "",
+ "open_ai_api_key": "YOUR API KEY",
+ "text_to_image": "dall-e-2",
+ "voice_to_text": "openai",
+ "text_to_voice": "openai",
+ "proxy": "",
+ "hot_reload": false,
+ "single_chat_prefix": [
+ "bot",
+ "@bot"
+ ],
+ "single_chat_reply_prefix": "[bot] ",
+ "group_chat_prefix": [
+ "@bot"
+ ],
+ "group_name_white_list": [
+ "ChatGPT测试群",
+ "ChatGPT测试群2"
+ ],
+ "image_create_prefix": [
+ "画"
+ ],
+ "speech_recognition": true,
+ "group_speech_recognition": false,
+ "voice_reply_voice": false,
+ "conversation_max_tokens": 2500,
+ "expires_in_seconds": 3600,
+ "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
+ "temperature": 0.7,
+ "subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
+ "use_linkai": false,
+ "linkai_api_key": "",
+ "linkai_app_code": ""
+}
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..154c633
--- /dev/null
+++ b/config.py
@@ -0,0 +1,312 @@
+# encoding:utf-8
+
+import json
+import logging
+import os
+import pickle
+
+from common.log import logger
+
+# 将所有可用的配置项写在字典里, 请使用小写字母
+# 此处的配置值无实际意义,程序不会读取此处的配置,仅用于提示格式,请将配置加入到config.json中
+available_setting = {
+ # openai api配置
+ "open_ai_api_key": "", # openai api key
+ # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
+ "open_ai_api_base": "https://api.openai.com/v1",
+ "proxy": "", # openai使用的代理
+ # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
+ "model": "gpt-3.5-turbo", # 还支持 gpt-4, gpt-4-turbo, wenxin, xunfei, qwen
+ "use_azure_chatgpt": False, # 是否使用azure的chatgpt
+ "azure_deployment_id": "", # azure 模型部署名称
+ "azure_api_version": "", # azure api版本
+ # Bot触发配置
+ "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
+ "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
+ "single_chat_reply_suffix": "", # 私聊时自动回复的后缀,\n 可以换行
+ "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
+ "group_chat_reply_prefix": "", # 群聊时自动回复的前缀
+ "group_chat_reply_suffix": "", # 群聊时自动回复的后缀,\n 可以换行
+ "group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
+ "group_at_off": False, # 是否关闭群聊时@bot的触发
+ "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
+ "group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
+ "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
+ "nick_name_black_list": [], # 用户昵称黑名单
+ "group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎
+ "trigger_by_self": False, # 是否允许机器人触发
+ "text_to_image": "dall-e-2", # 图片生成模型,可选 dall-e-2, dall-e-3
+ "image_proxy": True, # 是否需要图片代理,国内访问LinkAI时需要
+ "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
+ "concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
+ "image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024 (dall-e-3默认为1024x1024)
+ "group_chat_exit_group": False,
+ # chatgpt会话参数
+ "expires_in_seconds": 3600, # 无操作会话的过期时间
+ # 人格描述
+ "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
+ "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
+ # chatgpt限流配置
+ "rate_limit_chatgpt": 20, # chatgpt的调用频率限制
+ "rate_limit_dalle": 50, # openai dalle的调用频率限制
+ # chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
+ "temperature": 0.9,
+ "top_p": 1,
+ "frequency_penalty": 0,
+ "presence_penalty": 0,
+ "request_timeout": 180, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+ "timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
+ # Baidu 文心一言参数
+ "baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
+ "baidu_wenxin_api_key": "", # Baidu api key
+ "baidu_wenxin_secret_key": "", # Baidu secret key
+ # 讯飞星火API
+ "xunfei_app_id": "", # 讯飞应用ID
+ "xunfei_api_key": "", # 讯飞 API key
+ "xunfei_api_secret": "", # 讯飞 API secret
+ # claude 配置
+ "claude_api_cookie": "",
+ "claude_uuid": "",
+ # 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html
+ "qwen_access_key_id": "",
+ "qwen_access_key_secret": "",
+ "qwen_agent_key": "",
+ "qwen_app_id": "",
+ "qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
+ # Google Gemini Api Key
+ "gemini_api_key": "",
+ # wework的通用配置
+ "wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开
+ # 语音设置
+ "speech_recognition": True, # 是否开启语音识别
+ "group_speech_recognition": False, # 是否开启群组语音识别
+ "voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
+ "always_reply_voice": False, # 是否一直使用语音回复
+ "voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
+ "text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,pytts(offline),azure,elevenlabs
+ "text_to_voice_model": "tts-1",
+ "tts_voice_id": "alloy",
+ # baidu 语音api配置, 使用百度语音识别和语音合成时需要
+ "baidu_app_id": "",
+ "baidu_api_key": "",
+ "baidu_secret_key": "",
+ # 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
+ "baidu_dev_pid": "1536",
+ # azure 语音api配置, 使用azure语音识别和语音合成时需要
+ "azure_voice_api_key": "",
+ "azure_voice_region": "japaneast",
+ # elevenlabs 语音api配置
+ "xi_api_key": "", #获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
+ "xi_voice_id": "", #ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
+ # 服务时间限制,目前支持itchat
+ "chat_time_module": False, # 是否开启服务时间限制
+ "chat_start_time": "00:00", # 服务开始时间
+ "chat_stop_time": "24:00", # 服务结束时间
+ # 翻译api
+ "translate": "baidu", # 翻译api,支持baidu
+ # baidu翻译api的配置
+ "baidu_translate_app_id": "", # 百度翻译api的appid
+ "baidu_translate_app_key": "", # 百度翻译api的秘钥
+ # itchat的配置
+ "hot_reload": False, # 是否开启热重载
+ # wechaty的配置
+ "wechaty_puppet_service_token": "", # wechaty的token
+ # wechatmp的配置
+ "wechatmp_token": "", # 微信公众平台的Token
+ "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
+ "wechatmp_app_id": "", # 微信公众平台的appID
+ "wechatmp_app_secret": "", # 微信公众平台的appsecret
+ "wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
+ # wechatcom的通用配置
+ "wechatcom_corp_id": "", # 企业微信公司的corpID
+ # wechatcomapp的配置
+ "wechatcomapp_token": "", # 企业微信app的token
+ "wechatcomapp_port": 9898, # 企业微信app的服务端口,不需要端口转发
+ "wechatcomapp_secret": "", # 企业微信app的secret
+ "wechatcomapp_agent_id": "", # 企业微信app的agent_id
+ "wechatcomapp_aes_key": "", # 企业微信app的aes_key
+
+ # 飞书配置
+ "feishu_port": 80, # 飞书bot监听端口
+ "feishu_app_id": "", # 飞书机器人应用APP Id
+ "feishu_app_secret": "", # 飞书机器人APP secret
+ "feishu_token": "", # 飞书 verification token
+ "feishu_bot_name": "", # 飞书机器人的名字
+
+ # 钉钉配置
+ "dingtalk_client_id": "", # 钉钉机器人Client ID
+ "dingtalk_client_secret": "", # 钉钉机器人Client Secret
+
+ # chatgpt指令自定义触发词
+ "clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
+ # channel配置
+ "channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app}
+ "subscribe_msg": "", # 订阅消息, 支持: wechatmp, wechatmp_service, wechatcom_app
+ "debug": False, # 是否开启debug模式,开启后会打印更多日志
+ "appdata_dir": "", # 数据目录
+ # 插件配置
+ "plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
+ # 是否使用全局插件配置
+ "use_global_plugin_config": False,
+ "max_media_send_count": 3, # 单次最大发送媒体资源的个数
+ "media_send_interval": 1, # 发送图片的事件间隔,单位秒
+ # 智谱AI 平台配置
+ "zhipu_ai_api_key": "",
+ "zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
+ # LinkAI平台配置
+ "use_linkai": False,
+ "linkai_api_key": "",
+ "linkai_app_code": "",
+ "linkai_api_base": "https://api.link-ai.chat", # linkAI服务地址,若国内无法访问或延迟较高可改为 https://api.link-ai.tech
+}
+
+
+class Config(dict):
+ def __init__(self, d=None):
+ super().__init__()
+ if d is None:
+ d = {}
+ for k, v in d.items():
+ self[k] = v
+ # user_datas: 用户数据,key为用户名,value为用户数据,也是dict
+ self.user_datas = {}
+
+ def __getitem__(self, key):
+ if key not in available_setting:
+ raise Exception("key {} not in available_setting".format(key))
+ return super().__getitem__(key)
+
+ def __setitem__(self, key, value):
+ if key not in available_setting:
+ raise Exception("key {} not in available_setting".format(key))
+ return super().__setitem__(key, value)
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError as e:
+ return default
+ except Exception as e:
+ raise e
+
+ # Make sure to return a dictionary to ensure atomic
+ def get_user_data(self, user) -> dict:
+ if self.user_datas.get(user) is None:
+ self.user_datas[user] = {}
+ return self.user_datas[user]
+
+ def load_user_datas(self):
+ try:
+ with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "rb") as f:
+ self.user_datas = pickle.load(f)
+ logger.info("[Config] User datas loaded.")
+ except FileNotFoundError as e:
+ logger.info("[Config] User datas file not found, ignore.")
+ except Exception as e:
+ logger.info("[Config] User datas error: {}".format(e))
+ self.user_datas = {}
+
+ def save_user_datas(self):
+ try:
+ with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "wb") as f:
+ pickle.dump(self.user_datas, f)
+ logger.info("[Config] User datas saved.")
+ except Exception as e:
+ logger.info("[Config] User datas error: {}".format(e))
+
+
+config = Config()
+
+
+def load_config():
+ global config
+ config_path = "./config.json"
+ if not os.path.exists(config_path):
+ logger.info("配置文件不存在,将使用config-template.json模板")
+ config_path = "./config-template.json"
+
+ config_str = read_file(config_path)
+ logger.debug("[INIT] config str: {}".format(config_str))
+
+ # 将json字符串反序列化为dict类型
+ config = Config(json.loads(config_str))
+
+ # override config with environment variables.
+ # Some online deployment platforms (e.g. Railway) deploy project from github directly. So you shouldn't put your secrets like api key in a config file, instead use environment variables to override the default config.
+ for name, value in os.environ.items():
+ name = name.lower()
+ if name in available_setting:
+ logger.info("[INIT] override config by environ args: {}={}".format(name, value))
+ try:
+ config[name] = eval(value)
+ except:
+ if value == "false":
+ config[name] = False
+ elif value == "true":
+ config[name] = True
+ else:
+ config[name] = value
+
+ if config.get("debug", False):
+ logger.setLevel(logging.DEBUG)
+ logger.debug("[INIT] set log level to DEBUG")
+
+ logger.info("[INIT] load config: {}".format(config))
+
+ config.load_user_datas()
+
+
+def get_root():
+ return os.path.dirname(os.path.abspath(__file__))
+
+
+def read_file(path):
+ with open(path, mode="r", encoding="utf-8") as f:
+ return f.read()
+
+
+def conf():
+ return config
+
+
+def get_appdata_dir():
+ data_path = os.path.join(get_root(), conf().get("appdata_dir", ""))
+ if not os.path.exists(data_path):
+ logger.info("[INIT] data path not exists, create it: {}".format(data_path))
+ os.makedirs(data_path)
+ return data_path
+
+
+def subscribe_msg():
+ trigger_prefix = conf().get("single_chat_prefix", [""])[0]
+ msg = conf().get("subscribe_msg", "")
+ return msg.format(trigger_prefix=trigger_prefix)
+
+
+# global plugin config
+plugin_config = {}
+
+
+def write_plugin_config(pconf: dict):
+ """
+ 写入插件全局配置
+ :param pconf: 全量插件配置
+ """
+ global plugin_config
+ for k in pconf:
+ plugin_config[k.lower()] = pconf[k]
+
+
+def pconf(plugin_name: str) -> dict:
+ """
+ 根据插件名称获取配置
+ :param plugin_name: 插件名称
+ :return: 该插件的配置项
+ """
+ return plugin_config.get(plugin_name.lower())
+
+
+# 全局配置,用于存放全局生效的状态
+global_config = {
+ "admin_users": []
+}
diff --git a/docker/Dockerfile.latest b/docker/Dockerfile.latest
new file mode 100644
index 0000000..515ad3f
--- /dev/null
+++ b/docker/Dockerfile.latest
@@ -0,0 +1,35 @@
+FROM python:3.10-slim-bullseye
+
+LABEL maintainer="foo@bar.com"
+ARG TZ='Asia/Shanghai'
+
+ARG CHATGPT_ON_WECHAT_VER
+
+RUN echo /etc/apt/sources.list
+# RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list
+ENV BUILD_PREFIX=/app
+
+ADD . ${BUILD_PREFIX}
+
+RUN apt-get update \
+ &&apt-get install -y --no-install-recommends bash ffmpeg espeak libavcodec-extra\
+ && cd ${BUILD_PREFIX} \
+ && cp config-template.json config.json \
+ && /usr/local/bin/python -m pip install --no-cache --upgrade pip \
+ && pip install --no-cache -r requirements.txt \
+ && pip install --no-cache -r requirements-optional.txt \
+ && pip install azure-cognitiveservices-speech
+
+WORKDIR ${BUILD_PREFIX}
+
+ADD docker/entrypoint.sh /entrypoint.sh
+
+RUN chmod +x /entrypoint.sh \
+ && mkdir -p /home/noroot \
+ && groupadd -r noroot \
+ && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
+ && chown -R noroot:noroot /home/noroot ${BUILD_PREFIX} /usr/local/lib
+
+USER noroot
+
+ENTRYPOINT ["/entrypoint.sh"]
diff --git a/docker/build.latest.sh b/docker/build.latest.sh
new file mode 100644
index 0000000..92c3564
--- /dev/null
+++ b/docker/build.latest.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+unset KUBECONFIG
+
+cd .. && docker build -f docker/Dockerfile.latest \
+ -t zhayujie/chatgpt-on-wechat .
+
+docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$(date +%y%m%d)
\ No newline at end of file
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
new file mode 100644
index 0000000..8dbb1e4
--- /dev/null
+++ b/docker/docker-compose.yml
@@ -0,0 +1,24 @@
+version: '2.0'
+services:
+ chatgpt-on-wechat:
+ image: zhayujie/chatgpt-on-wechat
+ container_name: chatgpt-on-wechat
+ security_opt:
+ - seccomp:unconfined
+ environment:
+ OPEN_AI_API_KEY: 'YOUR API KEY'
+ MODEL: 'gpt-3.5-turbo'
+ PROXY: ''
+ SINGLE_CHAT_PREFIX: '["bot", "@bot"]'
+ SINGLE_CHAT_REPLY_PREFIX: '"[bot] "'
+ GROUP_CHAT_PREFIX: '["@bot"]'
+ GROUP_NAME_WHITE_LIST: '["ChatGPT测试群", "ChatGPT测试群2"]'
+ IMAGE_CREATE_PREFIX: '["画", "看", "找"]'
+ CONVERSATION_MAX_TOKENS: 1000
+ SPEECH_RECOGNITION: 'False'
+ CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
+ EXPIRES_IN_SECONDS: 3600
+ USE_GLOBAL_PLUGIN_CONFIG: 'True'
+ USE_LINKAI: 'False'
+ LINKAI_API_KEY: ''
+ LINKAI_APP_CODE: ''
diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh
new file mode 100755
index 0000000..f7f4cfa
--- /dev/null
+++ b/docker/entrypoint.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+set -e
+
+# build prefix
+CHATGPT_ON_WECHAT_PREFIX=${CHATGPT_ON_WECHAT_PREFIX:-""}
+# path to config.json
+CHATGPT_ON_WECHAT_CONFIG_PATH=${CHATGPT_ON_WECHAT_CONFIG_PATH:-""}
+# execution command line
+CHATGPT_ON_WECHAT_EXEC=${CHATGPT_ON_WECHAT_EXEC:-""}
+
+# use environment variables to pass parameters
+# if you have not defined environment variables, set them below
+# export OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-'YOUR API KEY'}
+# export OPEN_AI_PROXY=${OPEN_AI_PROXY:-""}
+# export SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-'["bot", "@bot"]'}
+# export SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-'"[bot] "'}
+# export GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-'["@bot"]'}
+# export GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-'["ChatGPT测试群", "ChatGPT测试群2"]'}
+# export IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-'["画", "看", "找"]'}
+# export CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-"1000"}
+# export SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-"False"}
+# export CHARACTER_DESC=${CHARACTER_DESC:-"你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"}
+# export EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-"3600"}
+
+# CHATGPT_ON_WECHAT_PREFIX is empty, use /app
+if [ "$CHATGPT_ON_WECHAT_PREFIX" == "" ] ; then
+ CHATGPT_ON_WECHAT_PREFIX=/app
+fi
+
+# CHATGPT_ON_WECHAT_CONFIG_PATH is empty, use '/app/config.json'
+if [ "$CHATGPT_ON_WECHAT_CONFIG_PATH" == "" ] ; then
+ CHATGPT_ON_WECHAT_CONFIG_PATH=$CHATGPT_ON_WECHAT_PREFIX/config.json
+fi
+
+# CHATGPT_ON_WECHAT_EXEC is empty, use ‘python app.py’
+if [ "$CHATGPT_ON_WECHAT_EXEC" == "" ] ; then
+ CHATGPT_ON_WECHAT_EXEC="python app.py"
+fi
+
+# modify content in config.json
+# if [ "$OPEN_AI_API_KEY" == "YOUR API KEY" ] || [ "$OPEN_AI_API_KEY" == "" ]; then
+# echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m"
+# fi
+
+
+# go to prefix dir
+cd $CHATGPT_ON_WECHAT_PREFIX
+# excute
+$CHATGPT_ON_WECHAT_EXEC
+
+
diff --git a/docs/images/aigcopen.png b/docs/images/aigcopen.png
new file mode 100644
index 0000000..76a20c6
Binary files /dev/null and b/docs/images/aigcopen.png differ
diff --git a/docs/images/contact.jpg b/docs/images/contact.jpg
new file mode 100644
index 0000000..3a8a412
Binary files /dev/null and b/docs/images/contact.jpg differ
diff --git a/docs/images/group-chat-sample.jpg b/docs/images/group-chat-sample.jpg
new file mode 100644
index 0000000..35fffda
Binary files /dev/null and b/docs/images/group-chat-sample.jpg differ
diff --git a/docs/images/image-create-sample.jpg b/docs/images/image-create-sample.jpg
new file mode 100644
index 0000000..5d916c5
Binary files /dev/null and b/docs/images/image-create-sample.jpg differ
diff --git a/docs/images/planet.jpg b/docs/images/planet.jpg
new file mode 100644
index 0000000..dffca7f
Binary files /dev/null and b/docs/images/planet.jpg differ
diff --git a/docs/images/single-chat-sample.jpg b/docs/images/single-chat-sample.jpg
new file mode 100644
index 0000000..f24b74d
Binary files /dev/null and b/docs/images/single-chat-sample.jpg differ
diff --git a/lib/itchat/LICENSE b/lib/itchat/LICENSE
new file mode 100644
index 0000000..ba1a0e2
--- /dev/null
+++ b/lib/itchat/LICENSE
@@ -0,0 +1,9 @@
+**The MIT License (MIT)**
+
+Copyright (c) 2017 LittleCoder ([littlecodersh@Github](https://github.com/littlecodersh))
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/lib/itchat/__init__.py b/lib/itchat/__init__.py
new file mode 100644
index 0000000..cccbdef
--- /dev/null
+++ b/lib/itchat/__init__.py
@@ -0,0 +1,96 @@
+from .core import Core
+from .config import VERSION, ASYNC_COMPONENTS
+from .log import set_logging
+
+if ASYNC_COMPONENTS:
+ from .async_components import load_components
+else:
+ from .components import load_components
+
+
+__version__ = VERSION
+
+
+instanceList = []
+
+def load_async_itchat() -> Core:
+ """load async-based itchat instance
+
+ Returns:
+ Core: the abstract interface of itchat
+ """
+ from .async_components import load_components
+ load_components(Core)
+ return Core()
+
+
+def load_sync_itchat() -> Core:
+ """load sync-based itchat instance
+
+ Returns:
+ Core: the abstract interface of itchat
+ """
+ from .components import load_components
+ load_components(Core)
+ return Core()
+
+
+if ASYNC_COMPONENTS:
+ instance = load_async_itchat()
+else:
+ instance = load_sync_itchat()
+
+
+instanceList = [instance]
+
+# I really want to use sys.modules[__name__] = originInstance
+# but it makes auto-fill a real mess, so forgive me for my following **
+# actually it toke me less than 30 seconds, god bless Uganda
+
+# components.login
+login = instance.login
+get_QRuuid = instance.get_QRuuid
+get_QR = instance.get_QR
+check_login = instance.check_login
+web_init = instance.web_init
+show_mobile_login = instance.show_mobile_login
+start_receiving = instance.start_receiving
+get_msg = instance.get_msg
+logout = instance.logout
+# components.contact
+update_chatroom = instance.update_chatroom
+update_friend = instance.update_friend
+get_contact = instance.get_contact
+get_friends = instance.get_friends
+get_chatrooms = instance.get_chatrooms
+get_mps = instance.get_mps
+set_alias = instance.set_alias
+set_pinned = instance.set_pinned
+accept_friend = instance.accept_friend
+get_head_img = instance.get_head_img
+create_chatroom = instance.create_chatroom
+set_chatroom_name = instance.set_chatroom_name
+delete_member_from_chatroom = instance.delete_member_from_chatroom
+add_member_into_chatroom = instance.add_member_into_chatroom
+# components.messages
+send_raw_msg = instance.send_raw_msg
+send_msg = instance.send_msg
+upload_file = instance.upload_file
+send_file = instance.send_file
+send_image = instance.send_image
+send_video = instance.send_video
+send = instance.send
+revoke = instance.revoke
+# components.hotreload
+dump_login_status = instance.dump_login_status
+load_login_status = instance.load_login_status
+# components.register
+auto_login = instance.auto_login
+configured_reply = instance.configured_reply
+msg_register = instance.msg_register
+run = instance.run
+# other functions
+search_friends = instance.search_friends
+search_chatrooms = instance.search_chatrooms
+search_mps = instance.search_mps
+set_logging = set_logging
diff --git a/lib/itchat/async_components/__init__.py b/lib/itchat/async_components/__init__.py
new file mode 100644
index 0000000..0fc321c
--- /dev/null
+++ b/lib/itchat/async_components/__init__.py
@@ -0,0 +1,12 @@
+from .contact import load_contact
+from .hotreload import load_hotreload
+from .login import load_login
+from .messages import load_messages
+from .register import load_register
+
+def load_components(core):
+ load_contact(core)
+ load_hotreload(core)
+ load_login(core)
+ load_messages(core)
+ load_register(core)
diff --git a/lib/itchat/async_components/contact.py b/lib/itchat/async_components/contact.py
new file mode 100644
index 0000000..440c288
--- /dev/null
+++ b/lib/itchat/async_components/contact.py
@@ -0,0 +1,488 @@
+import time, re, io
+import json, copy
+import logging
+
+from .. import config, utils
+from ..components.contact import accept_friend
+from ..returnvalues import ReturnValue
+from ..storage import contact_change
+from ..utils import update_info_dict
+
+logger = logging.getLogger('itchat')
+
+def load_contact(core):
+ core.update_chatroom = update_chatroom
+ core.update_friend = update_friend
+ core.get_contact = get_contact
+ core.get_friends = get_friends
+ core.get_chatrooms = get_chatrooms
+ core.get_mps = get_mps
+ core.set_alias = set_alias
+ core.set_pinned = set_pinned
+ core.accept_friend = accept_friend
+ core.get_head_img = get_head_img
+ core.create_chatroom = create_chatroom
+ core.set_chatroom_name = set_chatroom_name
+ core.delete_member_from_chatroom = delete_member_from_chatroom
+ core.add_member_into_chatroom = add_member_into_chatroom
+
+def update_chatroom(self, userName, detailedMember=False):
+ if not isinstance(userName, list):
+ userName = [userName]
+ url = '%s/webwxbatchgetcontact?type=ex&r=%s' % (
+ self.loginInfo['url'], int(time.time()))
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT }
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Count': len(userName),
+ 'List': [{
+ 'UserName': u,
+ 'ChatRoomId': '', } for u in userName], }
+ chatroomList = json.loads(self.s.post(url, data=json.dumps(data), headers=headers
+ ).content.decode('utf8', 'replace')).get('ContactList')
+ if not chatroomList:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No chatroom found',
+ 'Ret': -1001, }})
+
+ if detailedMember:
+ def get_detailed_member_info(encryChatroomId, memberList):
+ url = '%s/webwxbatchgetcontact?type=ex&r=%s' % (
+ self.loginInfo['url'], int(time.time()))
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT, }
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Count': len(memberList),
+ 'List': [{
+ 'UserName': member['UserName'],
+ 'EncryChatRoomId': encryChatroomId} \
+ for member in memberList], }
+ return json.loads(self.s.post(url, data=json.dumps(data), headers=headers
+ ).content.decode('utf8', 'replace'))['ContactList']
+ MAX_GET_NUMBER = 50
+ for chatroom in chatroomList:
+ totalMemberList = []
+ for i in range(int(len(chatroom['MemberList']) / MAX_GET_NUMBER + 1)):
+ memberList = chatroom['MemberList'][i*MAX_GET_NUMBER: (i+1)*MAX_GET_NUMBER]
+ totalMemberList += get_detailed_member_info(chatroom['EncryChatRoomId'], memberList)
+ chatroom['MemberList'] = totalMemberList
+
+ update_local_chatrooms(self, chatroomList)
+ r = [self.storageClass.search_chatrooms(userName=c['UserName'])
+ for c in chatroomList]
+ return r if 1 < len(r) else r[0]
+
+def update_friend(self, userName):
+ if not isinstance(userName, list):
+ userName = [userName]
+ url = '%s/webwxbatchgetcontact?type=ex&r=%s' % (
+ self.loginInfo['url'], int(time.time()))
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT }
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Count': len(userName),
+ 'List': [{
+ 'UserName': u,
+ 'EncryChatRoomId': '', } for u in userName], }
+ friendList = json.loads(self.s.post(url, data=json.dumps(data), headers=headers
+ ).content.decode('utf8', 'replace')).get('ContactList')
+
+ update_local_friends(self, friendList)
+ r = [self.storageClass.search_friends(userName=f['UserName'])
+ for f in friendList]
+ return r if len(r) != 1 else r[0]
+
+@contact_change
+def update_local_chatrooms(core, l):
+ '''
+ get a list of chatrooms for updating local chatrooms
+ return a list of given chatrooms with updated info
+ '''
+ for chatroom in l:
+ # format new chatrooms
+ utils.emoji_formatter(chatroom, 'NickName')
+ for member in chatroom['MemberList']:
+ if 'NickName' in member:
+ utils.emoji_formatter(member, 'NickName')
+ if 'DisplayName' in member:
+ utils.emoji_formatter(member, 'DisplayName')
+ if 'RemarkName' in member:
+ utils.emoji_formatter(member, 'RemarkName')
+ # update it to old chatrooms
+ oldChatroom = utils.search_dict_list(
+ core.chatroomList, 'UserName', chatroom['UserName'])
+ if oldChatroom:
+ update_info_dict(oldChatroom, chatroom)
+ # - update other values
+ memberList = chatroom.get('MemberList', [])
+ oldMemberList = oldChatroom['MemberList']
+ if memberList:
+ for member in memberList:
+ oldMember = utils.search_dict_list(
+ oldMemberList, 'UserName', member['UserName'])
+ if oldMember:
+ update_info_dict(oldMember, member)
+ else:
+ oldMemberList.append(member)
+ else:
+ core.chatroomList.append(chatroom)
+ oldChatroom = utils.search_dict_list(
+ core.chatroomList, 'UserName', chatroom['UserName'])
+ # delete useless members
+ if len(chatroom['MemberList']) != len(oldChatroom['MemberList']) and \
+ chatroom['MemberList']:
+ existsUserNames = [member['UserName'] for member in chatroom['MemberList']]
+ delList = []
+ for i, member in enumerate(oldChatroom['MemberList']):
+ if member['UserName'] not in existsUserNames:
+ delList.append(i)
+ delList.sort(reverse=True)
+ for i in delList:
+ del oldChatroom['MemberList'][i]
+ # - update OwnerUin
+ if oldChatroom.get('ChatRoomOwner') and oldChatroom.get('MemberList'):
+ owner = utils.search_dict_list(oldChatroom['MemberList'],
+ 'UserName', oldChatroom['ChatRoomOwner'])
+ oldChatroom['OwnerUin'] = (owner or {}).get('Uin', 0)
+ # - update IsAdmin
+ if 'OwnerUin' in oldChatroom and oldChatroom['OwnerUin'] != 0:
+ oldChatroom['IsAdmin'] = \
+ oldChatroom['OwnerUin'] == int(core.loginInfo['wxuin'])
+ else:
+ oldChatroom['IsAdmin'] = None
+ # - update Self
+ newSelf = utils.search_dict_list(oldChatroom['MemberList'],
+ 'UserName', core.storageClass.userName)
+ oldChatroom['Self'] = newSelf or copy.deepcopy(core.loginInfo['User'])
+ return {
+ 'Type' : 'System',
+ 'Text' : [chatroom['UserName'] for chatroom in l],
+ 'SystemInfo' : 'chatrooms',
+ 'FromUserName' : core.storageClass.userName,
+ 'ToUserName' : core.storageClass.userName, }
+
+@contact_change
+def update_local_friends(core, l):
+ '''
+ get a list of friends or mps for updating local contact
+ '''
+ fullList = core.memberList + core.mpList
+ for friend in l:
+ if 'NickName' in friend:
+ utils.emoji_formatter(friend, 'NickName')
+ if 'DisplayName' in friend:
+ utils.emoji_formatter(friend, 'DisplayName')
+ if 'RemarkName' in friend:
+ utils.emoji_formatter(friend, 'RemarkName')
+ oldInfoDict = utils.search_dict_list(
+ fullList, 'UserName', friend['UserName'])
+ if oldInfoDict is None:
+ oldInfoDict = copy.deepcopy(friend)
+ if oldInfoDict['VerifyFlag'] & 8 == 0:
+ core.memberList.append(oldInfoDict)
+ else:
+ core.mpList.append(oldInfoDict)
+ else:
+ update_info_dict(oldInfoDict, friend)
+
+@contact_change
+def update_local_uin(core, msg):
+ '''
+ content contains uins and StatusNotifyUserName contains username
+ they are in same order, so what I do is to pair them together
+
+ I caught an exception in this method while not knowing why
+ but don't worry, it won't cause any problem
+ '''
+ uins = re.search('([^<]*?)<', msg['Content'])
+ usernameChangedList = []
+ r = {
+ 'Type': 'System',
+ 'Text': usernameChangedList,
+ 'SystemInfo': 'uins', }
+ if uins:
+ uins = uins.group(1).split(',')
+ usernames = msg['StatusNotifyUserName'].split(',')
+ if 0 < len(uins) == len(usernames):
+ for uin, username in zip(uins, usernames):
+ if not '@' in username: continue
+ fullContact = core.memberList + core.chatroomList + core.mpList
+ userDicts = utils.search_dict_list(fullContact,
+ 'UserName', username)
+ if userDicts:
+ if userDicts.get('Uin', 0) == 0:
+ userDicts['Uin'] = uin
+ usernameChangedList.append(username)
+ logger.debug('Uin fetched: %s, %s' % (username, uin))
+ else:
+ if userDicts['Uin'] != uin:
+ logger.debug('Uin changed: %s, %s' % (
+ userDicts['Uin'], uin))
+ else:
+ if '@@' in username:
+ core.storageClass.updateLock.release()
+ update_chatroom(core, username)
+ core.storageClass.updateLock.acquire()
+ newChatroomDict = utils.search_dict_list(
+ core.chatroomList, 'UserName', username)
+ if newChatroomDict is None:
+ newChatroomDict = utils.struct_friend_info({
+ 'UserName': username,
+ 'Uin': uin,
+ 'Self': copy.deepcopy(core.loginInfo['User'])})
+ core.chatroomList.append(newChatroomDict)
+ else:
+ newChatroomDict['Uin'] = uin
+ elif '@' in username:
+ core.storageClass.updateLock.release()
+ update_friend(core, username)
+ core.storageClass.updateLock.acquire()
+ newFriendDict = utils.search_dict_list(
+ core.memberList, 'UserName', username)
+ if newFriendDict is None:
+ newFriendDict = utils.struct_friend_info({
+ 'UserName': username,
+ 'Uin': uin, })
+ core.memberList.append(newFriendDict)
+ else:
+ newFriendDict['Uin'] = uin
+ usernameChangedList.append(username)
+ logger.debug('Uin fetched: %s, %s' % (username, uin))
+ else:
+ logger.debug('Wrong length of uins & usernames: %s, %s' % (
+ len(uins), len(usernames)))
+ else:
+ logger.debug('No uins in 51 message')
+ logger.debug(msg['Content'])
+ return r
+
+def get_contact(self, update=False):
+ if not update:
+ return utils.contact_deep_copy(self, self.chatroomList)
+ def _get_contact(seq=0):
+ url = '%s/webwxgetcontact?r=%s&seq=%s&skey=%s' % (self.loginInfo['url'],
+ int(time.time()), seq, self.loginInfo['skey'])
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT, }
+ try:
+ r = self.s.get(url, headers=headers)
+ except:
+ logger.info('Failed to fetch contact, that may because of the amount of your chatrooms')
+ for chatroom in self.get_chatrooms():
+ self.update_chatroom(chatroom['UserName'], detailedMember=True)
+ return 0, []
+ j = json.loads(r.content.decode('utf-8', 'replace'))
+ return j.get('Seq', 0), j.get('MemberList')
+ seq, memberList = 0, []
+ while 1:
+ seq, batchMemberList = _get_contact(seq)
+ memberList.extend(batchMemberList)
+ if seq == 0:
+ break
+ chatroomList, otherList = [], []
+ for m in memberList:
+ if m['Sex'] != 0:
+ otherList.append(m)
+ elif '@@' in m['UserName']:
+ chatroomList.append(m)
+ elif '@' in m['UserName']:
+ # mp will be dealt in update_local_friends as well
+ otherList.append(m)
+ if chatroomList:
+ update_local_chatrooms(self, chatroomList)
+ if otherList:
+ update_local_friends(self, otherList)
+ return utils.contact_deep_copy(self, chatroomList)
+
+def get_friends(self, update=False):
+ if update:
+ self.get_contact(update=True)
+ return utils.contact_deep_copy(self, self.memberList)
+
+def get_chatrooms(self, update=False, contactOnly=False):
+ if contactOnly:
+ return self.get_contact(update=True)
+ else:
+ if update:
+ self.get_contact(True)
+ return utils.contact_deep_copy(self, self.chatroomList)
+
+def get_mps(self, update=False):
+ if update: self.get_contact(update=True)
+ return utils.contact_deep_copy(self, self.mpList)
+
+def set_alias(self, userName, alias):
+ oldFriendInfo = utils.search_dict_list(
+ self.memberList, 'UserName', userName)
+ if oldFriendInfo is None:
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1001, }})
+ url = '%s/webwxoplog?lang=%s&pass_ticket=%s' % (
+ self.loginInfo['url'], 'zh_CN', self.loginInfo['pass_ticket'])
+ data = {
+ 'UserName' : userName,
+ 'CmdId' : 2,
+ 'RemarkName' : alias,
+ 'BaseRequest' : self.loginInfo['BaseRequest'], }
+ headers = { 'User-Agent' : config.USER_AGENT}
+ r = self.s.post(url, json.dumps(data, ensure_ascii=False).encode('utf8'),
+ headers=headers)
+ r = ReturnValue(rawResponse=r)
+ if r:
+ oldFriendInfo['RemarkName'] = alias
+ return r
+
+def set_pinned(self, userName, isPinned=True):
+ url = '%s/webwxoplog?pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'UserName' : userName,
+ 'CmdId' : 3,
+ 'OP' : int(isPinned),
+ 'BaseRequest' : self.loginInfo['BaseRequest'], }
+ headers = { 'User-Agent' : config.USER_AGENT}
+ r = self.s.post(url, json=data, headers=headers)
+ return ReturnValue(rawResponse=r)
+
+def accept_friend(self, userName, v4= '', autoUpdate=True):
+ url = f"{self.loginInfo['url']}/webwxverifyuser?r={int(time.time())}&pass_ticket={self.loginInfo['pass_ticket']}"
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Opcode': 3, # 3
+ 'VerifyUserListSize': 1,
+ 'VerifyUserList': [{
+ 'Value': userName,
+ 'VerifyUserTicket': v4, }],
+ 'VerifyContent': '',
+ 'SceneListCount': 1,
+ 'SceneList': [33],
+ 'skey': self.loginInfo['skey'], }
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8', 'replace'))
+ if autoUpdate:
+ self.update_friend(userName)
+ return ReturnValue(rawResponse=r)
+
+def get_head_img(self, userName=None, chatroomUserName=None, picDir=None):
+ ''' get head image
+ * if you want to get chatroom header: only set chatroomUserName
+ * if you want to get friend header: only set userName
+ * if you want to get chatroom member header: set both
+ '''
+ params = {
+ 'userName': userName or chatroomUserName or self.storageClass.userName,
+ 'skey': self.loginInfo['skey'],
+ 'type': 'big', }
+ url = '%s/webwxgeticon' % self.loginInfo['url']
+ if chatroomUserName is None:
+ infoDict = self.storageClass.search_friends(userName=userName)
+ if infoDict is None:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No friend found',
+ 'Ret': -1001, }})
+ else:
+ if userName is None:
+ url = '%s/webwxgetheadimg' % self.loginInfo['url']
+ else:
+ chatroom = self.storageClass.search_chatrooms(userName=chatroomUserName)
+ if chatroomUserName is None:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No chatroom found',
+ 'Ret': -1001, }})
+ if 'EncryChatRoomId' in chatroom:
+ params['chatroomid'] = chatroom['EncryChatRoomId']
+ params['chatroomid'] = params.get('chatroomid') or chatroom['UserName']
+ headers = { 'User-Agent' : config.USER_AGENT}
+ r = self.s.get(url, params=params, stream=True, headers=headers)
+ tempStorage = io.BytesIO()
+ for block in r.iter_content(1024):
+ tempStorage.write(block)
+ if picDir is None:
+ return tempStorage.getvalue()
+ with open(picDir, 'wb') as f:
+ f.write(tempStorage.getvalue())
+ tempStorage.seek(0)
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Successfully downloaded',
+ 'Ret': 0, },
+ 'PostFix': utils.get_image_postfix(tempStorage.read(20)), })
+
+def create_chatroom(self, memberList, topic=''):
+ url = '%s/webwxcreatechatroom?pass_ticket=%s&r=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'], int(time.time()))
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'MemberCount': len(memberList.split(',')),
+ 'MemberList': [{'UserName': member} for member in memberList.split(',')],
+ 'Topic': topic, }
+ headers = {
+ 'content-type': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8', 'ignore'))
+ return ReturnValue(rawResponse=r)
+
+def set_chatroom_name(self, chatroomUserName, name):
+ url = '%s/webwxupdatechatroom?fun=modtopic&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'ChatRoomName': chatroomUserName,
+ 'NewTopic': name, }
+ headers = {
+ 'content-type': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8', 'ignore'))
+ return ReturnValue(rawResponse=r)
+
+def delete_member_from_chatroom(self, chatroomUserName, memberList):
+ url = '%s/webwxupdatechatroom?fun=delmember&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'ChatRoomName': chatroomUserName,
+ 'DelMemberList': ','.join([member['UserName'] for member in memberList]), }
+ headers = {
+ 'content-type': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT}
+ r = self.s.post(url, data=json.dumps(data),headers=headers)
+ return ReturnValue(rawResponse=r)
+
+def add_member_into_chatroom(self, chatroomUserName, memberList,
+ useInvitation=False):
+ ''' add or invite member into chatroom
+ * there are two ways to get members into chatroom: invite or directly add
+ * but for chatrooms with more than 40 users, you can only use invite
+ * but don't worry we will auto-force userInvitation for you when necessary
+ '''
+ if not useInvitation:
+ chatroom = self.storageClass.search_chatrooms(userName=chatroomUserName)
+ if not chatroom: chatroom = self.update_chatroom(chatroomUserName)
+ if len(chatroom['MemberList']) > self.loginInfo['InviteStartCount']:
+ useInvitation = True
+ if useInvitation:
+ fun, memberKeyName = 'invitemember', 'InviteMemberList'
+ else:
+ fun, memberKeyName = 'addmember', 'AddMemberList'
+ url = '%s/webwxupdatechatroom?fun=%s&pass_ticket=%s' % (
+ self.loginInfo['url'], fun, self.loginInfo['pass_ticket'])
+ params = {
+ 'BaseRequest' : self.loginInfo['BaseRequest'],
+ 'ChatRoomName' : chatroomUserName,
+ memberKeyName : memberList, }
+ headers = {
+ 'content-type': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT}
+ r = self.s.post(url, data=json.dumps(params),headers=headers)
+ return ReturnValue(rawResponse=r)
diff --git a/lib/itchat/async_components/hotreload.py b/lib/itchat/async_components/hotreload.py
new file mode 100644
index 0000000..b0bb54c
--- /dev/null
+++ b/lib/itchat/async_components/hotreload.py
@@ -0,0 +1,102 @@
+import pickle, os
+import logging
+
+import requests # type: ignore
+
+from ..config import VERSION
+from ..returnvalues import ReturnValue
+from ..storage import templates
+from .contact import update_local_chatrooms, update_local_friends
+from .messages import produce_msg
+
+logger = logging.getLogger('itchat')
+
+def load_hotreload(core):
+ core.dump_login_status = dump_login_status
+ core.load_login_status = load_login_status
+
+async def dump_login_status(self, fileDir=None):
+ fileDir = fileDir or self.hotReloadDir
+ try:
+ with open(fileDir, 'w') as f:
+ f.write('itchat - DELETE THIS')
+ os.remove(fileDir)
+ except:
+ raise Exception('Incorrect fileDir')
+ status = {
+ 'version' : VERSION,
+ 'loginInfo' : self.loginInfo,
+ 'cookies' : self.s.cookies.get_dict(),
+ 'storage' : self.storageClass.dumps()}
+ with open(fileDir, 'wb') as f:
+ pickle.dump(status, f)
+ logger.debug('Dump login status for hot reload successfully.')
+
+async def load_login_status(self, fileDir,
+ loginCallback=None, exitCallback=None):
+ try:
+ with open(fileDir, 'rb') as f:
+ j = pickle.load(f)
+ except Exception as e:
+ logger.debug('No such file, loading login status failed.')
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No such file, loading login status failed.',
+ 'Ret': -1002, }})
+
+ if j.get('version', '') != VERSION:
+ logger.debug(('you have updated itchat from %s to %s, ' +
+ 'so cached status is ignored') % (
+ j.get('version', 'old version'), VERSION))
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'cached status ignored because of version',
+ 'Ret': -1005, }})
+ self.loginInfo = j['loginInfo']
+ self.loginInfo['User'] = templates.User(self.loginInfo['User'])
+ self.loginInfo['User'].core = self
+ self.s.cookies = requests.utils.cookiejar_from_dict(j['cookies'])
+ self.storageClass.loads(j['storage'])
+ try:
+ msgList, contactList = self.get_msg()
+ except:
+ msgList = contactList = None
+ if (msgList or contactList) is None:
+ self.logout()
+ await load_last_login_status(self.s, j['cookies'])
+ logger.debug('server refused, loading login status failed.')
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'server refused, loading login status failed.',
+ 'Ret': -1003, }})
+ else:
+ if contactList:
+ for contact in contactList:
+ if '@@' in contact['UserName']:
+ update_local_chatrooms(self, [contact])
+ else:
+ update_local_friends(self, [contact])
+ if msgList:
+ msgList = produce_msg(self, msgList)
+ for msg in msgList: self.msgList.put(msg)
+ await self.start_receiving(exitCallback)
+ logger.debug('loading login status succeeded.')
+ if hasattr(loginCallback, '__call__'):
+ await loginCallback(self.storageClass.userName)
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'loading login status succeeded.',
+ 'Ret': 0, }})
+
+async def load_last_login_status(session, cookiesDict):
+ try:
+ session.cookies = requests.utils.cookiejar_from_dict({
+ 'webwxuvid': cookiesDict['webwxuvid'],
+ 'webwx_auth_ticket': cookiesDict['webwx_auth_ticket'],
+ 'login_frequency': '2',
+ 'last_wxuin': cookiesDict['wxuin'],
+ 'wxloadtime': cookiesDict['wxloadtime'] + '_expired',
+ 'wxpluginkey': cookiesDict['wxloadtime'],
+ 'wxuin': cookiesDict['wxuin'],
+ 'mm_lang': 'zh_CN',
+ 'MM_WX_NOTIFY_STATE': '1',
+ 'MM_WX_SOUND_STATE': '1', })
+ except:
+ logger.info('Load status for push login failed, we may have experienced a cookies change.')
+ logger.info('If you are using the newest version of itchat, you may report a bug.')
diff --git a/lib/itchat/async_components/login.py b/lib/itchat/async_components/login.py
new file mode 100644
index 0000000..59f3542
--- /dev/null
+++ b/lib/itchat/async_components/login.py
@@ -0,0 +1,422 @@
+import asyncio
+import os, time, re, io
+import threading
+import json
+import random
+import traceback
+import logging
+try:
+ from httplib import BadStatusLine
+except ImportError:
+ from http.client import BadStatusLine
+
+import requests # type: ignore
+from pyqrcode import QRCode
+
+from .. import config, utils
+from ..returnvalues import ReturnValue
+from ..storage.templates import wrap_user_dict
+from .contact import update_local_chatrooms, update_local_friends
+from .messages import produce_msg
+
+logger = logging.getLogger('itchat')
+
+
+def load_login(core):
+ core.login = login
+ core.get_QRuuid = get_QRuuid
+ core.get_QR = get_QR
+ core.check_login = check_login
+ core.web_init = web_init
+ core.show_mobile_login = show_mobile_login
+ core.start_receiving = start_receiving
+ core.get_msg = get_msg
+ core.logout = logout
+
+async def login(self, enableCmdQR=False, picDir=None, qrCallback=None, EventScanPayload=None,ScanStatus=None,event_stream=None,
+ loginCallback=None, exitCallback=None):
+ if self.alive or self.isLogging:
+ logger.warning('itchat has already logged in.')
+ return
+ self.isLogging = True
+
+ while self.isLogging:
+ uuid = await push_login(self)
+ if uuid:
+ payload = EventScanPayload(
+ status=ScanStatus.Waiting,
+ qrcode=f"qrcode/https://login.weixin.qq.com/l/{uuid}"
+ )
+ event_stream.emit('scan', payload)
+ await asyncio.sleep(0.1)
+ else:
+ logger.info('Getting uuid of QR code.')
+ self.get_QRuuid()
+ payload = EventScanPayload(
+ status=ScanStatus.Waiting,
+ qrcode=f"https://login.weixin.qq.com/l/{self.uuid}"
+ )
+ print(f"https://wechaty.js.org/qrcode/https://login.weixin.qq.com/l/{self.uuid}")
+ event_stream.emit('scan', payload)
+ await asyncio.sleep(0.1)
+ # logger.info('Please scan the QR code to log in.')
+ isLoggedIn = False
+ while not isLoggedIn:
+ status = await self.check_login()
+ # if hasattr(qrCallback, '__call__'):
+ # await qrCallback(uuid=self.uuid, status=status, qrcode=self.qrStorage.getvalue())
+ if status == '200':
+ isLoggedIn = True
+ payload = EventScanPayload(
+ status=ScanStatus.Scanned,
+ qrcode=f"https://login.weixin.qq.com/l/{self.uuid}"
+ )
+ event_stream.emit('scan', payload)
+ await asyncio.sleep(0.1)
+ elif status == '201':
+ if isLoggedIn is not None:
+ logger.info('Please press confirm on your phone.')
+ isLoggedIn = None
+ payload = EventScanPayload(
+ status=ScanStatus.Waiting,
+ qrcode=f"https://login.weixin.qq.com/l/{self.uuid}"
+ )
+ event_stream.emit('scan', payload)
+ await asyncio.sleep(0.1)
+ elif status != '408':
+ payload = EventScanPayload(
+ status=ScanStatus.Cancel,
+ qrcode=f"https://login.weixin.qq.com/l/{self.uuid}"
+ )
+ event_stream.emit('scan', payload)
+ await asyncio.sleep(0.1)
+ break
+ if isLoggedIn:
+ payload = EventScanPayload(
+ status=ScanStatus.Confirmed,
+ qrcode=f"https://login.weixin.qq.com/l/{self.uuid}"
+ )
+ event_stream.emit('scan', payload)
+ await asyncio.sleep(0.1)
+ break
+ elif self.isLogging:
+ logger.info('Log in time out, reloading QR code.')
+ payload = EventScanPayload(
+ status=ScanStatus.Timeout,
+ qrcode=f"https://login.weixin.qq.com/l/{self.uuid}"
+ )
+ event_stream.emit('scan', payload)
+ await asyncio.sleep(0.1)
+ else:
+ return
+ logger.info('Loading the contact, this may take a little while.')
+ await self.web_init()
+ await self.show_mobile_login()
+ self.get_contact(True)
+ if hasattr(loginCallback, '__call__'):
+ r = await loginCallback(self.storageClass.userName)
+ else:
+ utils.clear_screen()
+ if os.path.exists(picDir or config.DEFAULT_QR):
+ os.remove(picDir or config.DEFAULT_QR)
+ logger.info('Login successfully as %s' % self.storageClass.nickName)
+ await self.start_receiving(exitCallback)
+ self.isLogging = False
+
+async def push_login(core):
+ cookiesDict = core.s.cookies.get_dict()
+ if 'wxuin' in cookiesDict:
+ url = '%s/cgi-bin/mmwebwx-bin/webwxpushloginurl?uin=%s' % (
+ config.BASE_URL, cookiesDict['wxuin'])
+ headers = { 'User-Agent' : config.USER_AGENT}
+ r = core.s.get(url, headers=headers).json()
+ if 'uuid' in r and r.get('ret') in (0, '0'):
+ core.uuid = r['uuid']
+ return r['uuid']
+ return False
+
+def get_QRuuid(self):
+ url = '%s/jslogin' % config.BASE_URL
+ params = {
+ 'appid' : 'wx782c26e4c19acffb',
+ 'fun' : 'new',
+ 'redirect_uri' : 'https://wx.qq.com/cgi-bin/mmwebwx-bin/webwxnewloginpage?mod=desktop',
+ 'lang' : 'zh_CN' }
+ headers = { 'User-Agent' : config.USER_AGENT}
+ r = self.s.get(url, params=params, headers=headers)
+ regx = r'window.QRLogin.code = (\d+); window.QRLogin.uuid = "(\S+?)";'
+ data = re.search(regx, r.text)
+ if data and data.group(1) == '200':
+ self.uuid = data.group(2)
+ return self.uuid
+
+async def get_QR(self, uuid=None, enableCmdQR=False, picDir=None, qrCallback=None):
+ uuid = uuid or self.uuid
+ picDir = picDir or config.DEFAULT_QR
+ qrStorage = io.BytesIO()
+ qrCode = QRCode('https://login.weixin.qq.com/l/' + uuid)
+ qrCode.png(qrStorage, scale=10)
+ if hasattr(qrCallback, '__call__'):
+ await qrCallback(uuid=uuid, status='0', qrcode=qrStorage.getvalue())
+ else:
+ with open(picDir, 'wb') as f:
+ f.write(qrStorage.getvalue())
+ if enableCmdQR:
+ utils.print_cmd_qr(qrCode.text(1), enableCmdQR=enableCmdQR)
+ else:
+ utils.print_qr(picDir)
+ return qrStorage
+
+async def check_login(self, uuid=None):
+ uuid = uuid or self.uuid
+ url = '%s/cgi-bin/mmwebwx-bin/login' % config.BASE_URL
+ localTime = int(time.time())
+ params = 'loginicon=true&uuid=%s&tip=1&r=%s&_=%s' % (
+ uuid, int(-localTime / 1579), localTime)
+ headers = { 'User-Agent' : config.USER_AGENT}
+ r = self.s.get(url, params=params, headers=headers)
+ regx = r'window.code=(\d+)'
+ data = re.search(regx, r.text)
+ if data and data.group(1) == '200':
+ if await process_login_info(self, r.text):
+ return '200'
+ else:
+ return '400'
+ elif data:
+ return data.group(1)
+ else:
+ return '400'
+
+async def process_login_info(core, loginContent):
+ ''' when finish login (scanning qrcode)
+ * syncUrl and fileUploadingUrl will be fetched
+ * deviceid and msgid will be generated
+ * skey, wxsid, wxuin, pass_ticket will be fetched
+ '''
+ regx = r'window.redirect_uri="(\S+)";'
+ core.loginInfo['url'] = re.search(regx, loginContent).group(1)
+ headers = { 'User-Agent' : config.USER_AGENT,
+ 'client-version' : config.UOS_PATCH_CLIENT_VERSION,
+ 'extspam' : config.UOS_PATCH_EXTSPAM,
+ 'referer' : 'https://wx.qq.com/?&lang=zh_CN&target=t'
+ }
+ r = core.s.get(core.loginInfo['url'], headers=headers, allow_redirects=False)
+ core.loginInfo['url'] = core.loginInfo['url'][:core.loginInfo['url'].rfind('/')]
+ for indexUrl, detailedUrl in (
+ ("wx2.qq.com" , ("file.wx2.qq.com", "webpush.wx2.qq.com")),
+ ("wx8.qq.com" , ("file.wx8.qq.com", "webpush.wx8.qq.com")),
+ ("qq.com" , ("file.wx.qq.com", "webpush.wx.qq.com")),
+ ("web2.wechat.com" , ("file.web2.wechat.com", "webpush.web2.wechat.com")),
+ ("wechat.com" , ("file.web.wechat.com", "webpush.web.wechat.com"))):
+ fileUrl, syncUrl = ['https://%s/cgi-bin/mmwebwx-bin' % url for url in detailedUrl]
+ if indexUrl in core.loginInfo['url']:
+ core.loginInfo['fileUrl'], core.loginInfo['syncUrl'] = \
+ fileUrl, syncUrl
+ break
+ else:
+ core.loginInfo['fileUrl'] = core.loginInfo['syncUrl'] = core.loginInfo['url']
+ core.loginInfo['deviceid'] = 'e' + repr(random.random())[2:17]
+ core.loginInfo['logintime'] = int(time.time() * 1e3)
+ core.loginInfo['BaseRequest'] = {}
+ cookies = core.s.cookies.get_dict()
+ skey = re.findall('(.*?)', r.text, re.S)[0]
+ pass_ticket = re.findall('(.*?)', r.text, re.S)[0]
+ core.loginInfo['skey'] = core.loginInfo['BaseRequest']['Skey'] = skey
+ core.loginInfo['wxsid'] = core.loginInfo['BaseRequest']['Sid'] = cookies["wxsid"]
+ core.loginInfo['wxuin'] = core.loginInfo['BaseRequest']['Uin'] = cookies["wxuin"]
+ core.loginInfo['pass_ticket'] = pass_ticket
+
+ # A question : why pass_ticket == DeviceID ?
+ # deviceID is only a randomly generated number
+
+ # UOS PATCH By luvletter2333, Sun Feb 28 10:00 PM
+ # for node in xml.dom.minidom.parseString(r.text).documentElement.childNodes:
+ # if node.nodeName == 'skey':
+ # core.loginInfo['skey'] = core.loginInfo['BaseRequest']['Skey'] = node.childNodes[0].data
+ # elif node.nodeName == 'wxsid':
+ # core.loginInfo['wxsid'] = core.loginInfo['BaseRequest']['Sid'] = node.childNodes[0].data
+ # elif node.nodeName == 'wxuin':
+ # core.loginInfo['wxuin'] = core.loginInfo['BaseRequest']['Uin'] = node.childNodes[0].data
+ # elif node.nodeName == 'pass_ticket':
+ # core.loginInfo['pass_ticket'] = core.loginInfo['BaseRequest']['DeviceID'] = node.childNodes[0].data
+ if not all([key in core.loginInfo for key in ('skey', 'wxsid', 'wxuin', 'pass_ticket')]):
+ logger.error('Your wechat account may be LIMITED to log in WEB wechat, error info:\n%s' % r.text)
+ core.isLogging = False
+ return False
+ return True
+
+async def web_init(self):
+ url = '%s/webwxinit' % self.loginInfo['url']
+ params = {
+ 'r': int(-time.time() / 1579),
+ 'pass_ticket': self.loginInfo['pass_ticket'], }
+ data = { 'BaseRequest': self.loginInfo['BaseRequest'], }
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT, }
+ r = self.s.post(url, params=params, data=json.dumps(data), headers=headers)
+ dic = json.loads(r.content.decode('utf-8', 'replace'))
+ # deal with login info
+ utils.emoji_formatter(dic['User'], 'NickName')
+ self.loginInfo['InviteStartCount'] = int(dic['InviteStartCount'])
+ self.loginInfo['User'] = wrap_user_dict(utils.struct_friend_info(dic['User']))
+ self.memberList.append(self.loginInfo['User'])
+ self.loginInfo['SyncKey'] = dic['SyncKey']
+ self.loginInfo['synckey'] = '|'.join(['%s_%s' % (item['Key'], item['Val'])
+ for item in dic['SyncKey']['List']])
+ self.storageClass.userName = dic['User']['UserName']
+ self.storageClass.nickName = dic['User']['NickName']
+ # deal with contact list returned when init
+ contactList = dic.get('ContactList', [])
+ chatroomList, otherList = [], []
+ for m in contactList:
+ if m['Sex'] != 0:
+ otherList.append(m)
+ elif '@@' in m['UserName']:
+ m['MemberList'] = [] # don't let dirty info pollute the list
+ chatroomList.append(m)
+ elif '@' in m['UserName']:
+ # mp will be dealt in update_local_friends as well
+ otherList.append(m)
+ if chatroomList:
+ update_local_chatrooms(self, chatroomList)
+ if otherList:
+ update_local_friends(self, otherList)
+ return dic
+
+async def show_mobile_login(self):
+ url = '%s/webwxstatusnotify?lang=zh_CN&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest' : self.loginInfo['BaseRequest'],
+ 'Code' : 3,
+ 'FromUserName' : self.storageClass.userName,
+ 'ToUserName' : self.storageClass.userName,
+ 'ClientMsgId' : int(time.time()), }
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT, }
+ r = self.s.post(url, data=json.dumps(data), headers=headers)
+ return ReturnValue(rawResponse=r)
+
+async def start_receiving(self, exitCallback=None, getReceivingFnOnly=False):
+ self.alive = True
+ def maintain_loop():
+ retryCount = 0
+ while self.alive:
+ try:
+ i = sync_check(self)
+ if i is None:
+ self.alive = False
+ elif i == '0':
+ pass
+ else:
+ msgList, contactList = self.get_msg()
+ if msgList:
+ msgList = produce_msg(self, msgList)
+ for msg in msgList:
+ self.msgList.put(msg)
+ if contactList:
+ chatroomList, otherList = [], []
+ for contact in contactList:
+ if '@@' in contact['UserName']:
+ chatroomList.append(contact)
+ else:
+ otherList.append(contact)
+ chatroomMsg = update_local_chatrooms(self, chatroomList)
+ chatroomMsg['User'] = self.loginInfo['User']
+ self.msgList.put(chatroomMsg)
+ update_local_friends(self, otherList)
+ retryCount = 0
+ except requests.exceptions.ReadTimeout:
+ pass
+ except:
+ retryCount += 1
+ logger.error(traceback.format_exc())
+ if self.receivingRetryCount < retryCount:
+ self.alive = False
+ else:
+ time.sleep(1)
+ self.logout()
+ if hasattr(exitCallback, '__call__'):
+ exitCallback(self.storageClass.userName)
+ else:
+ logger.info('LOG OUT!')
+ if getReceivingFnOnly:
+ return maintain_loop
+ else:
+ maintainThread = threading.Thread(target=maintain_loop)
+ maintainThread.setDaemon(True)
+ maintainThread.start()
+
+def sync_check(self):
+ url = '%s/synccheck' % self.loginInfo.get('syncUrl', self.loginInfo['url'])
+ params = {
+ 'r' : int(time.time() * 1000),
+ 'skey' : self.loginInfo['skey'],
+ 'sid' : self.loginInfo['wxsid'],
+ 'uin' : self.loginInfo['wxuin'],
+ 'deviceid' : self.loginInfo['deviceid'],
+ 'synckey' : self.loginInfo['synckey'],
+ '_' : self.loginInfo['logintime'], }
+ headers = { 'User-Agent' : config.USER_AGENT}
+ self.loginInfo['logintime'] += 1
+ try:
+ r = self.s.get(url, params=params, headers=headers, timeout=config.TIMEOUT)
+ except requests.exceptions.ConnectionError as e:
+ try:
+ if not isinstance(e.args[0].args[1], BadStatusLine):
+ raise
+ # will return a package with status '0 -'
+ # and value like:
+ # 6f:00:8a:9c:09:74:e4:d8:e0:14:bf:96:3a:56:a0:64:1b:a4:25:5d:12:f4:31:a5:30:f1:c6:48:5f:c3:75:6a:99:93
+ # seems like status of typing, but before I make further achievement code will remain like this
+ return '2'
+ except:
+ raise
+ r.raise_for_status()
+ regx = r'window.synccheck={retcode:"(\d+)",selector:"(\d+)"}'
+ pm = re.search(regx, r.text)
+ if pm is None or pm.group(1) != '0':
+ logger.debug('Unexpected sync check result: %s' % r.text)
+ return None
+ return pm.group(2)
+
+def get_msg(self):
+ self.loginInfo['deviceid'] = 'e' + repr(random.random())[2:17]
+ url = '%s/webwxsync?sid=%s&skey=%s&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['wxsid'],
+ self.loginInfo['skey'],self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest' : self.loginInfo['BaseRequest'],
+ 'SyncKey' : self.loginInfo['SyncKey'],
+ 'rr' : ~int(time.time()), }
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT }
+ r = self.s.post(url, data=json.dumps(data), headers=headers, timeout=config.TIMEOUT)
+ dic = json.loads(r.content.decode('utf-8', 'replace'))
+ if dic['BaseResponse']['Ret'] != 0: return None, None
+ self.loginInfo['SyncKey'] = dic['SyncKey']
+ self.loginInfo['synckey'] = '|'.join(['%s_%s' % (item['Key'], item['Val'])
+ for item in dic['SyncCheckKey']['List']])
+ return dic['AddMsgList'], dic['ModContactList']
+
+def logout(self):
+ if self.alive:
+ url = '%s/webwxlogout' % self.loginInfo['url']
+ params = {
+ 'redirect' : 1,
+ 'type' : 1,
+ 'skey' : self.loginInfo['skey'], }
+ headers = { 'User-Agent' : config.USER_AGENT}
+ self.s.get(url, params=params, headers=headers)
+ self.alive = False
+ self.isLogging = False
+ self.s.cookies.clear()
+ del self.chatroomList[:]
+ del self.memberList[:]
+ del self.mpList[:]
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'logout successfully.',
+ 'Ret': 0, }})
diff --git a/lib/itchat/async_components/messages.py b/lib/itchat/async_components/messages.py
new file mode 100644
index 0000000..f842f1f
--- /dev/null
+++ b/lib/itchat/async_components/messages.py
@@ -0,0 +1,527 @@
+import os, time, re, io
+import json
+import mimetypes, hashlib
+import logging
+from collections import OrderedDict
+
+
+from .. import config, utils
+from ..returnvalues import ReturnValue
+from ..storage import templates
+from .contact import update_local_uin
+
+logger = logging.getLogger('itchat')
+
+def load_messages(core):
+ core.send_raw_msg = send_raw_msg
+ core.send_msg = send_msg
+ core.upload_file = upload_file
+ core.send_file = send_file
+ core.send_image = send_image
+ core.send_video = send_video
+ core.send = send
+ core.revoke = revoke
+
+async def get_download_fn(core, url, msgId):
+ async def download_fn(downloadDir=None):
+ params = {
+ 'msgid': msgId,
+ 'skey': core.loginInfo['skey'],}
+ headers = { 'User-Agent' : config.USER_AGENT}
+ r = core.s.get(url, params=params, stream=True, headers = headers)
+ tempStorage = io.BytesIO()
+ for block in r.iter_content(1024):
+ tempStorage.write(block)
+ if downloadDir is None:
+ return tempStorage.getvalue()
+ with open(downloadDir, 'wb') as f:
+ f.write(tempStorage.getvalue())
+ tempStorage.seek(0)
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Successfully downloaded',
+ 'Ret': 0, },
+ 'PostFix': utils.get_image_postfix(tempStorage.read(20)), })
+ return download_fn
+
+def produce_msg(core, msgList):
+ ''' for messages types
+ * 40 msg, 43 videochat, 50 VOIPMSG, 52 voipnotifymsg
+ * 53 webwxvoipnotifymsg, 9999 sysnotice
+ '''
+ rl = []
+ srl = [40, 43, 50, 52, 53, 9999]
+ for m in msgList:
+ # get actual opposite
+ if m['FromUserName'] == core.storageClass.userName:
+ actualOpposite = m['ToUserName']
+ else:
+ actualOpposite = m['FromUserName']
+ # produce basic message
+ if '@@' in m['FromUserName'] or '@@' in m['ToUserName']:
+ produce_group_chat(core, m)
+ else:
+ utils.msg_formatter(m, 'Content')
+ # set user of msg
+ if '@@' in actualOpposite:
+ m['User'] = core.search_chatrooms(userName=actualOpposite) or \
+ templates.Chatroom({'UserName': actualOpposite})
+ # we don't need to update chatroom here because we have
+ # updated once when producing basic message
+ elif actualOpposite in ('filehelper', 'fmessage'):
+ m['User'] = templates.User({'UserName': actualOpposite})
+ else:
+ m['User'] = core.search_mps(userName=actualOpposite) or \
+ core.search_friends(userName=actualOpposite) or \
+ templates.User(userName=actualOpposite)
+ # by default we think there may be a user missing not a mp
+ m['User'].core = core
+ if m['MsgType'] == 1: # words
+ if m['Url']:
+ regx = r'(.+?\(.+?\))'
+ data = re.search(regx, m['Content'])
+ data = 'Map' if data is None else data.group(1)
+ msg = {
+ 'Type': 'Map',
+ 'Text': data,}
+ else:
+ msg = {
+ 'Type': 'Text',
+ 'Text': m['Content'],}
+ elif m['MsgType'] == 3 or m['MsgType'] == 47: # picture
+ download_fn = get_download_fn(core,
+ '%s/webwxgetmsgimg' % core.loginInfo['url'], m['NewMsgId'])
+ msg = {
+ 'Type' : 'Picture',
+ 'FileName' : '%s.%s' % (time.strftime('%y%m%d-%H%M%S', time.localtime()),
+ 'png' if m['MsgType'] == 3 else 'gif'),
+ 'Text' : download_fn, }
+ elif m['MsgType'] == 34: # voice
+ download_fn = get_download_fn(core,
+ '%s/webwxgetvoice' % core.loginInfo['url'], m['NewMsgId'])
+ msg = {
+ 'Type': 'Recording',
+ 'FileName' : '%s.mp3' % time.strftime('%y%m%d-%H%M%S', time.localtime()),
+ 'Text': download_fn,}
+ elif m['MsgType'] == 37: # friends
+ m['User']['UserName'] = m['RecommendInfo']['UserName']
+ msg = {
+ 'Type': 'Friends',
+ 'Text': {
+ 'status' : m['Status'],
+ 'userName' : m['RecommendInfo']['UserName'],
+ 'verifyContent' : m['Ticket'],
+ 'autoUpdate' : m['RecommendInfo'], }, }
+ m['User'].verifyDict = msg['Text']
+ elif m['MsgType'] == 42: # name card
+ msg = {
+ 'Type': 'Card',
+ 'Text': m['RecommendInfo'], }
+ elif m['MsgType'] in (43, 62): # tiny video
+ msgId = m['MsgId']
+ async def download_video(videoDir=None):
+ url = '%s/webwxgetvideo' % core.loginInfo['url']
+ params = {
+ 'msgid': msgId,
+ 'skey': core.loginInfo['skey'],}
+ headers = {'Range': 'bytes=0-', 'User-Agent' : config.USER_AGENT}
+ r = core.s.get(url, params=params, headers=headers, stream=True)
+ tempStorage = io.BytesIO()
+ for block in r.iter_content(1024):
+ tempStorage.write(block)
+ if videoDir is None:
+ return tempStorage.getvalue()
+ with open(videoDir, 'wb') as f:
+ f.write(tempStorage.getvalue())
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Successfully downloaded',
+ 'Ret': 0, }})
+ msg = {
+ 'Type': 'Video',
+ 'FileName' : '%s.mp4' % time.strftime('%y%m%d-%H%M%S', time.localtime()),
+ 'Text': download_video, }
+ elif m['MsgType'] == 49: # sharing
+ if m['AppMsgType'] == 0: # chat history
+ msg = {
+ 'Type': 'Note',
+ 'Text': m['Content'], }
+ elif m['AppMsgType'] == 6:
+ rawMsg = m
+ cookiesList = {name:data for name,data in core.s.cookies.items()}
+ async def download_atta(attaDir=None):
+ url = core.loginInfo['fileUrl'] + '/webwxgetmedia'
+ params = {
+ 'sender': rawMsg['FromUserName'],
+ 'mediaid': rawMsg['MediaId'],
+ 'filename': rawMsg['FileName'],
+ 'fromuser': core.loginInfo['wxuin'],
+ 'pass_ticket': 'undefined',
+ 'webwx_data_ticket': cookiesList['webwx_data_ticket'],}
+ headers = { 'User-Agent' : config.USER_AGENT}
+ r = core.s.get(url, params=params, stream=True, headers=headers)
+ tempStorage = io.BytesIO()
+ for block in r.iter_content(1024):
+ tempStorage.write(block)
+ if attaDir is None:
+ return tempStorage.getvalue()
+ with open(attaDir, 'wb') as f:
+ f.write(tempStorage.getvalue())
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Successfully downloaded',
+ 'Ret': 0, }})
+ msg = {
+ 'Type': 'Attachment',
+ 'Text': download_atta, }
+ elif m['AppMsgType'] == 8:
+ download_fn = get_download_fn(core,
+ '%s/webwxgetmsgimg' % core.loginInfo['url'], m['NewMsgId'])
+ msg = {
+ 'Type' : 'Picture',
+ 'FileName' : '%s.gif' % (
+ time.strftime('%y%m%d-%H%M%S', time.localtime())),
+ 'Text' : download_fn, }
+ elif m['AppMsgType'] == 17:
+ msg = {
+ 'Type': 'Note',
+ 'Text': m['FileName'], }
+ elif m['AppMsgType'] == 2000:
+ regx = r'\[CDATA\[(.+?)\][\s\S]+?\[CDATA\[(.+?)\]'
+ data = re.search(regx, m['Content'])
+ if data:
+ data = data.group(2).split(u'\u3002')[0]
+ else:
+ data = 'You may found detailed info in Content key.'
+ msg = {
+ 'Type': 'Note',
+ 'Text': data, }
+ else:
+ msg = {
+ 'Type': 'Sharing',
+ 'Text': m['FileName'], }
+ elif m['MsgType'] == 51: # phone init
+ msg = update_local_uin(core, m)
+ elif m['MsgType'] == 10000:
+ msg = {
+ 'Type': 'Note',
+ 'Text': m['Content'],}
+ elif m['MsgType'] == 10002:
+ regx = r'\[CDATA\[(.+?)\]\]'
+ data = re.search(regx, m['Content'])
+ data = 'System message' if data is None else data.group(1).replace('\\', '')
+ msg = {
+ 'Type': 'Note',
+ 'Text': data, }
+ elif m['MsgType'] in srl:
+ msg = {
+ 'Type': 'Useless',
+ 'Text': 'UselessMsg', }
+ else:
+ logger.debug('Useless message received: %s\n%s' % (m['MsgType'], str(m)))
+ msg = {
+ 'Type': 'Useless',
+ 'Text': 'UselessMsg', }
+ m = dict(m, **msg)
+ rl.append(m)
+ return rl
+
+def produce_group_chat(core, msg):
+ r = re.match('(@[0-9a-z]*?):
(.*)$', msg['Content'])
+ if r:
+ actualUserName, content = r.groups()
+ chatroomUserName = msg['FromUserName']
+ elif msg['FromUserName'] == core.storageClass.userName:
+ actualUserName = core.storageClass.userName
+ content = msg['Content']
+ chatroomUserName = msg['ToUserName']
+ else:
+ msg['ActualUserName'] = core.storageClass.userName
+ msg['ActualNickName'] = core.storageClass.nickName
+ msg['IsAt'] = False
+ utils.msg_formatter(msg, 'Content')
+ return
+ chatroom = core.storageClass.search_chatrooms(userName=chatroomUserName)
+ member = utils.search_dict_list((chatroom or {}).get(
+ 'MemberList') or [], 'UserName', actualUserName)
+ if member is None:
+ chatroom = core.update_chatroom(chatroomUserName)
+ member = utils.search_dict_list((chatroom or {}).get(
+ 'MemberList') or [], 'UserName', actualUserName)
+ if member is None:
+ logger.debug('chatroom member fetch failed with %s' % actualUserName)
+ msg['ActualNickName'] = ''
+ msg['IsAt'] = False
+ else:
+ msg['ActualNickName'] = member.get('DisplayName', '') or member['NickName']
+ atFlag = '@' + (chatroom['Self'].get('DisplayName', '') or core.storageClass.nickName)
+ msg['IsAt'] = (
+ (atFlag + (u'\u2005' if u'\u2005' in msg['Content'] else ' '))
+ in msg['Content'] or msg['Content'].endswith(atFlag))
+ msg['ActualUserName'] = actualUserName
+ msg['Content'] = content
+ utils.msg_formatter(msg, 'Content')
+
+async def send_raw_msg(self, msgType, content, toUserName):
+ url = '%s/webwxsendmsg' % self.loginInfo['url']
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Msg': {
+ 'Type': msgType,
+ 'Content': content,
+ 'FromUserName': self.storageClass.userName,
+ 'ToUserName': (toUserName if toUserName else self.storageClass.userName),
+ 'LocalID': int(time.time() * 1e4),
+ 'ClientMsgId': int(time.time() * 1e4),
+ },
+ 'Scene': 0, }
+ headers = { 'ContentType': 'application/json; charset=UTF-8', 'User-Agent' : config.USER_AGENT}
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
+
+async def send_msg(self, msg='Test Message', toUserName=None):
+ logger.debug('Request to send a text message to %s: %s' % (toUserName, msg))
+ r = await self.send_raw_msg(1, msg, toUserName)
+ return r
+
+def _prepare_file(fileDir, file_=None):
+ fileDict = {}
+ if file_:
+ if hasattr(file_, 'read'):
+ file_ = file_.read()
+ else:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'file_ param should be opened file',
+ 'Ret': -1005, }})
+ else:
+ if not utils.check_file(fileDir):
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No file found in specific dir',
+ 'Ret': -1002, }})
+ with open(fileDir, 'rb') as f:
+ file_ = f.read()
+ fileDict['fileSize'] = len(file_)
+ fileDict['fileMd5'] = hashlib.md5(file_).hexdigest()
+ fileDict['file_'] = io.BytesIO(file_)
+ return fileDict
+
+def upload_file(self, fileDir, isPicture=False, isVideo=False,
+ toUserName='filehelper', file_=None, preparedFile=None):
+ logger.debug('Request to upload a %s: %s' % (
+ 'picture' if isPicture else 'video' if isVideo else 'file', fileDir))
+ if not preparedFile:
+ preparedFile = _prepare_file(fileDir, file_)
+ if not preparedFile:
+ return preparedFile
+ fileSize, fileMd5, file_ = \
+ preparedFile['fileSize'], preparedFile['fileMd5'], preparedFile['file_']
+ fileSymbol = 'pic' if isPicture else 'video' if isVideo else'doc'
+ chunks = int((fileSize - 1) / 524288) + 1
+ clientMediaId = int(time.time() * 1e4)
+ uploadMediaRequest = json.dumps(OrderedDict([
+ ('UploadType', 2),
+ ('BaseRequest', self.loginInfo['BaseRequest']),
+ ('ClientMediaId', clientMediaId),
+ ('TotalLen', fileSize),
+ ('StartPos', 0),
+ ('DataLen', fileSize),
+ ('MediaType', 4),
+ ('FromUserName', self.storageClass.userName),
+ ('ToUserName', toUserName),
+ ('FileMd5', fileMd5)]
+ ), separators = (',', ':'))
+ r = {'BaseResponse': {'Ret': -1005, 'ErrMsg': 'Empty file detected'}}
+ for chunk in range(chunks):
+ r = upload_chunk_file(self, fileDir, fileSymbol, fileSize,
+ file_, chunk, chunks, uploadMediaRequest)
+ file_.close()
+ if isinstance(r, dict):
+ return ReturnValue(r)
+ return ReturnValue(rawResponse=r)
+
+def upload_chunk_file(core, fileDir, fileSymbol, fileSize,
+ file_, chunk, chunks, uploadMediaRequest):
+ url = core.loginInfo.get('fileUrl', core.loginInfo['url']) + \
+ '/webwxuploadmedia?f=json'
+ # save it on server
+ cookiesList = {name:data for name,data in core.s.cookies.items()}
+ fileType = mimetypes.guess_type(fileDir)[0] or 'application/octet-stream'
+ fileName = utils.quote(os.path.basename(fileDir))
+ files = OrderedDict([
+ ('id', (None, 'WU_FILE_0')),
+ ('name', (None, fileName)),
+ ('type', (None, fileType)),
+ ('lastModifiedDate', (None, time.strftime('%a %b %d %Y %H:%M:%S GMT+0800 (CST)'))),
+ ('size', (None, str(fileSize))),
+ ('chunks', (None, None)),
+ ('chunk', (None, None)),
+ ('mediatype', (None, fileSymbol)),
+ ('uploadmediarequest', (None, uploadMediaRequest)),
+ ('webwx_data_ticket', (None, cookiesList['webwx_data_ticket'])),
+ ('pass_ticket', (None, core.loginInfo['pass_ticket'])),
+ ('filename' , (fileName, file_.read(524288), 'application/octet-stream'))])
+ if chunks == 1:
+ del files['chunk']; del files['chunks']
+ else:
+ files['chunk'], files['chunks'] = (None, str(chunk)), (None, str(chunks))
+ headers = { 'User-Agent' : config.USER_AGENT}
+ return core.s.post(url, files=files, headers=headers, timeout=config.TIMEOUT)
+
+async def send_file(self, fileDir, toUserName=None, mediaId=None, file_=None):
+ logger.debug('Request to send a file(mediaId: %s) to %s: %s' % (
+ mediaId, toUserName, fileDir))
+ if hasattr(fileDir, 'read'):
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'fileDir param should not be an opened file in send_file',
+ 'Ret': -1005, }})
+ if toUserName is None:
+ toUserName = self.storageClass.userName
+ preparedFile = _prepare_file(fileDir, file_)
+ if not preparedFile:
+ return preparedFile
+ fileSize = preparedFile['fileSize']
+ if mediaId is None:
+ r = self.upload_file(fileDir, preparedFile=preparedFile)
+ if r:
+ mediaId = r['MediaId']
+ else:
+ return r
+ url = '%s/webwxsendappmsg?fun=async&f=json' % self.loginInfo['url']
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Msg': {
+ 'Type': 6,
+ 'Content': ("%s" % os.path.basename(fileDir) +
+ "6" +
+ "%s%s" % (str(fileSize), mediaId) +
+ "%s" % os.path.splitext(fileDir)[1].replace('.','')),
+ 'FromUserName': self.storageClass.userName,
+ 'ToUserName': toUserName,
+ 'LocalID': int(time.time() * 1e4),
+ 'ClientMsgId': int(time.time() * 1e4), },
+ 'Scene': 0, }
+ headers = {
+ 'User-Agent': config.USER_AGENT,
+ 'Content-Type': 'application/json;charset=UTF-8', }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
+
+async def send_image(self, fileDir=None, toUserName=None, mediaId=None, file_=None):
+ logger.debug('Request to send a image(mediaId: %s) to %s: %s' % (
+ mediaId, toUserName, fileDir))
+ if fileDir or file_:
+ if hasattr(fileDir, 'read'):
+ file_, fileDir = fileDir, None
+ if fileDir is None:
+ fileDir = 'tmp.jpg' # specific fileDir to send gifs
+ else:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Either fileDir or file_ should be specific',
+ 'Ret': -1005, }})
+ if toUserName is None:
+ toUserName = self.storageClass.userName
+ if mediaId is None:
+ r = self.upload_file(fileDir, isPicture=not fileDir[-4:] == '.gif', file_=file_)
+ if r:
+ mediaId = r['MediaId']
+ else:
+ return r
+ url = '%s/webwxsendmsgimg?fun=async&f=json' % self.loginInfo['url']
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Msg': {
+ 'Type': 3,
+ 'MediaId': mediaId,
+ 'FromUserName': self.storageClass.userName,
+ 'ToUserName': toUserName,
+ 'LocalID': int(time.time() * 1e4),
+ 'ClientMsgId': int(time.time() * 1e4), },
+ 'Scene': 0, }
+ if fileDir[-4:] == '.gif':
+ url = '%s/webwxsendemoticon?fun=sys' % self.loginInfo['url']
+ data['Msg']['Type'] = 47
+ data['Msg']['EmojiFlag'] = 2
+ headers = {
+ 'User-Agent': config.USER_AGENT,
+ 'Content-Type': 'application/json;charset=UTF-8', }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
+
+async def send_video(self, fileDir=None, toUserName=None, mediaId=None, file_=None):
+ logger.debug('Request to send a video(mediaId: %s) to %s: %s' % (
+ mediaId, toUserName, fileDir))
+ if fileDir or file_:
+ if hasattr(fileDir, 'read'):
+ file_, fileDir = fileDir, None
+ if fileDir is None:
+ fileDir = 'tmp.mp4' # specific fileDir to send other formats
+ else:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Either fileDir or file_ should be specific',
+ 'Ret': -1005, }})
+ if toUserName is None:
+ toUserName = self.storageClass.userName
+ if mediaId is None:
+ r = self.upload_file(fileDir, isVideo=True, file_=file_)
+ if r:
+ mediaId = r['MediaId']
+ else:
+ return r
+ url = '%s/webwxsendvideomsg?fun=async&f=json&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Msg': {
+ 'Type' : 43,
+ 'MediaId' : mediaId,
+ 'FromUserName' : self.storageClass.userName,
+ 'ToUserName' : toUserName,
+ 'LocalID' : int(time.time() * 1e4),
+ 'ClientMsgId' : int(time.time() * 1e4), },
+ 'Scene': 0, }
+ headers = {
+ 'User-Agent' : config.USER_AGENT,
+ 'Content-Type': 'application/json;charset=UTF-8', }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
+
+async def send(self, msg, toUserName=None, mediaId=None):
+ if not msg:
+ r = ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No message.',
+ 'Ret': -1005, }})
+ elif msg[:5] == '@fil@':
+ if mediaId is None:
+ r = await self.send_file(msg[5:], toUserName)
+ else:
+ r = await self.send_file(msg[5:], toUserName, mediaId)
+ elif msg[:5] == '@img@':
+ if mediaId is None:
+ r = await self.send_image(msg[5:], toUserName)
+ else:
+ r = await self.send_image(msg[5:], toUserName, mediaId)
+ elif msg[:5] == '@msg@':
+ r = await self.send_msg(msg[5:], toUserName)
+ elif msg[:5] == '@vid@':
+ if mediaId is None:
+ r = await self.send_video(msg[5:], toUserName)
+ else:
+ r = await self.send_video(msg[5:], toUserName, mediaId)
+ else:
+ r = await self.send_msg(msg, toUserName)
+ return r
+
+async def revoke(self, msgId, toUserName, localId=None):
+ url = '%s/webwxrevokemsg' % self.loginInfo['url']
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ "ClientMsgId": localId or str(time.time() * 1e3),
+ "SvrMsgId": msgId,
+ "ToUserName": toUserName}
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
diff --git a/lib/itchat/async_components/register.py b/lib/itchat/async_components/register.py
new file mode 100644
index 0000000..cb4f12b
--- /dev/null
+++ b/lib/itchat/async_components/register.py
@@ -0,0 +1,106 @@
+import logging, traceback, sys, threading
+try:
+ import Queue
+except ImportError:
+ import queue as Queue # type: ignore
+
+from ..log import set_logging
+from ..utils import test_connect
+from ..storage import templates
+
+logger = logging.getLogger('itchat')
+
+def load_register(core):
+ core.auto_login = auto_login
+ core.configured_reply = configured_reply
+ core.msg_register = msg_register
+ core.run = run
+
+async def auto_login(self, EventScanPayload=None,ScanStatus=None,event_stream=None,
+ hotReload=True, statusStorageDir='itchat.pkl',
+ enableCmdQR=False, picDir=None, qrCallback=None,
+ loginCallback=None, exitCallback=None):
+ if not test_connect():
+ logger.info("You can't get access to internet or wechat domain, so exit.")
+ sys.exit()
+ self.useHotReload = hotReload
+ self.hotReloadDir = statusStorageDir
+ if hotReload:
+ if await self.load_login_status(statusStorageDir,
+ loginCallback=loginCallback, exitCallback=exitCallback):
+ return
+ await self.login(enableCmdQR=enableCmdQR, picDir=picDir, qrCallback=qrCallback, EventScanPayload=EventScanPayload, ScanStatus=ScanStatus, event_stream=event_stream,
+ loginCallback=loginCallback, exitCallback=exitCallback)
+ await self.dump_login_status(statusStorageDir)
+ else:
+ await self.login(enableCmdQR=enableCmdQR, picDir=picDir, qrCallback=qrCallback, EventScanPayload=EventScanPayload, ScanStatus=ScanStatus, event_stream=event_stream,
+ loginCallback=loginCallback, exitCallback=exitCallback)
+
+async def configured_reply(self, event_stream, payload, message_container):
+ ''' determine the type of message and reply if its method is defined
+ however, I use a strange way to determine whether a msg is from massive platform
+ I haven't found a better solution here
+ The main problem I'm worrying about is the mismatching of new friends added on phone
+ If you have any good idea, pleeeease report an issue. I will be more than grateful.
+ '''
+ try:
+ msg = self.msgList.get(timeout=1)
+ if 'MsgId' in msg.keys():
+ message_container[msg['MsgId']] = msg
+ except Queue.Empty:
+ pass
+ else:
+ if isinstance(msg['User'], templates.User):
+ replyFn = self.functionDict['FriendChat'].get(msg['Type'])
+ elif isinstance(msg['User'], templates.MassivePlatform):
+ replyFn = self.functionDict['MpChat'].get(msg['Type'])
+ elif isinstance(msg['User'], templates.Chatroom):
+ replyFn = self.functionDict['GroupChat'].get(msg['Type'])
+ if replyFn is None:
+ r = None
+ else:
+ try:
+ r = await replyFn(msg)
+ if r is not None:
+ await self.send(r, msg.get('FromUserName'))
+ except:
+ logger.warning(traceback.format_exc())
+
+def msg_register(self, msgType, isFriendChat=False, isGroupChat=False, isMpChat=False):
+ ''' a decorator constructor
+ return a specific decorator based on information given '''
+ if not (isinstance(msgType, list) or isinstance(msgType, tuple)):
+ msgType = [msgType]
+ def _msg_register(fn):
+ for _msgType in msgType:
+ if isFriendChat:
+ self.functionDict['FriendChat'][_msgType] = fn
+ if isGroupChat:
+ self.functionDict['GroupChat'][_msgType] = fn
+ if isMpChat:
+ self.functionDict['MpChat'][_msgType] = fn
+ if not any((isFriendChat, isGroupChat, isMpChat)):
+ self.functionDict['FriendChat'][_msgType] = fn
+ return fn
+ return _msg_register
+
+async def run(self, debug=False, blockThread=True):
+ logger.info('Start auto replying.')
+ if debug:
+ set_logging(loggingLevel=logging.DEBUG)
+ async def reply_fn():
+ try:
+ while self.alive:
+ await self.configured_reply()
+ except KeyboardInterrupt:
+ if self.useHotReload:
+ await self.dump_login_status()
+ self.alive = False
+ logger.debug('itchat received an ^C and exit.')
+ logger.info('Bye~')
+ if blockThread:
+ await reply_fn()
+ else:
+ replyThread = threading.Thread(target=reply_fn)
+ replyThread.setDaemon(True)
+ replyThread.start()
diff --git a/lib/itchat/components/__init__.py b/lib/itchat/components/__init__.py
new file mode 100644
index 0000000..0fc321c
--- /dev/null
+++ b/lib/itchat/components/__init__.py
@@ -0,0 +1,12 @@
+from .contact import load_contact
+from .hotreload import load_hotreload
+from .login import load_login
+from .messages import load_messages
+from .register import load_register
+
+def load_components(core):
+ load_contact(core)
+ load_hotreload(core)
+ load_login(core)
+ load_messages(core)
+ load_register(core)
diff --git a/lib/itchat/components/contact.py b/lib/itchat/components/contact.py
new file mode 100644
index 0000000..93e3d16
--- /dev/null
+++ b/lib/itchat/components/contact.py
@@ -0,0 +1,519 @@
+import time
+import re
+import io
+import json
+import copy
+import logging
+
+from .. import config, utils
+from ..returnvalues import ReturnValue
+from ..storage import contact_change
+from ..utils import update_info_dict
+
+logger = logging.getLogger('itchat')
+
+
+def load_contact(core):
+ core.update_chatroom = update_chatroom
+ core.update_friend = update_friend
+ core.get_contact = get_contact
+ core.get_friends = get_friends
+ core.get_chatrooms = get_chatrooms
+ core.get_mps = get_mps
+ core.set_alias = set_alias
+ core.set_pinned = set_pinned
+ core.accept_friend = accept_friend
+ core.get_head_img = get_head_img
+ core.create_chatroom = create_chatroom
+ core.set_chatroom_name = set_chatroom_name
+ core.delete_member_from_chatroom = delete_member_from_chatroom
+ core.add_member_into_chatroom = add_member_into_chatroom
+
+
+def update_chatroom(self, userName, detailedMember=False):
+ if not isinstance(userName, list):
+ userName = [userName]
+ url = '%s/webwxbatchgetcontact?type=ex&r=%s' % (
+ self.loginInfo['url'], int(time.time()))
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT}
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Count': len(userName),
+ 'List': [{
+ 'UserName': u,
+ 'ChatRoomId': '', } for u in userName], }
+ chatroomList = json.loads(self.s.post(url, data=json.dumps(data), headers=headers
+ ).content.decode('utf8', 'replace')).get('ContactList')
+ if not chatroomList:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No chatroom found',
+ 'Ret': -1001, }})
+
+ if detailedMember:
+ def get_detailed_member_info(encryChatroomId, memberList):
+ url = '%s/webwxbatchgetcontact?type=ex&r=%s' % (
+ self.loginInfo['url'], int(time.time()))
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT, }
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Count': len(memberList),
+ 'List': [{
+ 'UserName': member['UserName'],
+ 'EncryChatRoomId': encryChatroomId}
+ for member in memberList], }
+ return json.loads(self.s.post(url, data=json.dumps(data), headers=headers
+ ).content.decode('utf8', 'replace'))['ContactList']
+ MAX_GET_NUMBER = 50
+ for chatroom in chatroomList:
+ totalMemberList = []
+ for i in range(int(len(chatroom['MemberList']) / MAX_GET_NUMBER + 1)):
+ memberList = chatroom['MemberList'][i *
+ MAX_GET_NUMBER: (i+1)*MAX_GET_NUMBER]
+ totalMemberList += get_detailed_member_info(
+ chatroom['EncryChatRoomId'], memberList)
+ chatroom['MemberList'] = totalMemberList
+
+ update_local_chatrooms(self, chatroomList)
+ r = [self.storageClass.search_chatrooms(userName=c['UserName'])
+ for c in chatroomList]
+ return r if 1 < len(r) else r[0]
+
+
+def update_friend(self, userName):
+ if not isinstance(userName, list):
+ userName = [userName]
+ url = '%s/webwxbatchgetcontact?type=ex&r=%s' % (
+ self.loginInfo['url'], int(time.time()))
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT}
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Count': len(userName),
+ 'List': [{
+ 'UserName': u,
+ 'EncryChatRoomId': '', } for u in userName], }
+ friendList = json.loads(self.s.post(url, data=json.dumps(data), headers=headers
+ ).content.decode('utf8', 'replace')).get('ContactList')
+
+ update_local_friends(self, friendList)
+ r = [self.storageClass.search_friends(userName=f['UserName'])
+ for f in friendList]
+ return r if len(r) != 1 else r[0]
+
+
+@contact_change
+def update_local_chatrooms(core, l):
+ '''
+ get a list of chatrooms for updating local chatrooms
+ return a list of given chatrooms with updated info
+ '''
+ for chatroom in l:
+ # format new chatrooms
+ utils.emoji_formatter(chatroom, 'NickName')
+ for member in chatroom['MemberList']:
+ if 'NickName' in member:
+ utils.emoji_formatter(member, 'NickName')
+ if 'DisplayName' in member:
+ utils.emoji_formatter(member, 'DisplayName')
+ if 'RemarkName' in member:
+ utils.emoji_formatter(member, 'RemarkName')
+ # update it to old chatrooms
+ oldChatroom = utils.search_dict_list(
+ core.chatroomList, 'UserName', chatroom['UserName'])
+ if oldChatroom:
+ update_info_dict(oldChatroom, chatroom)
+ # - update other values
+ memberList = chatroom.get('MemberList', [])
+ oldMemberList = oldChatroom['MemberList']
+ if memberList:
+ for member in memberList:
+ oldMember = utils.search_dict_list(
+ oldMemberList, 'UserName', member['UserName'])
+ if oldMember:
+ update_info_dict(oldMember, member)
+ else:
+ oldMemberList.append(member)
+ else:
+ core.chatroomList.append(chatroom)
+ oldChatroom = utils.search_dict_list(
+ core.chatroomList, 'UserName', chatroom['UserName'])
+ # delete useless members
+ if len(chatroom['MemberList']) != len(oldChatroom['MemberList']) and \
+ chatroom['MemberList']:
+ existsUserNames = [member['UserName']
+ for member in chatroom['MemberList']]
+ delList = []
+ for i, member in enumerate(oldChatroom['MemberList']):
+ if member['UserName'] not in existsUserNames:
+ delList.append(i)
+ delList.sort(reverse=True)
+ for i in delList:
+ del oldChatroom['MemberList'][i]
+ # - update OwnerUin
+ if oldChatroom.get('ChatRoomOwner') and oldChatroom.get('MemberList'):
+ owner = utils.search_dict_list(oldChatroom['MemberList'],
+ 'UserName', oldChatroom['ChatRoomOwner'])
+ oldChatroom['OwnerUin'] = (owner or {}).get('Uin', 0)
+ # - update IsAdmin
+ if 'OwnerUin' in oldChatroom and oldChatroom['OwnerUin'] != 0:
+ oldChatroom['IsAdmin'] = \
+ oldChatroom['OwnerUin'] == int(core.loginInfo['wxuin'])
+ else:
+ oldChatroom['IsAdmin'] = None
+ # - update Self
+ newSelf = utils.search_dict_list(oldChatroom['MemberList'],
+ 'UserName', core.storageClass.userName)
+ oldChatroom['Self'] = newSelf or copy.deepcopy(core.loginInfo['User'])
+ return {
+ 'Type': 'System',
+ 'Text': [chatroom['UserName'] for chatroom in l],
+ 'SystemInfo': 'chatrooms',
+ 'FromUserName': core.storageClass.userName,
+ 'ToUserName': core.storageClass.userName, }
+
+
+@contact_change
+def update_local_friends(core, l):
+ '''
+ get a list of friends or mps for updating local contact
+ '''
+ fullList = core.memberList + core.mpList
+ for friend in l:
+ if 'NickName' in friend:
+ utils.emoji_formatter(friend, 'NickName')
+ if 'DisplayName' in friend:
+ utils.emoji_formatter(friend, 'DisplayName')
+ if 'RemarkName' in friend:
+ utils.emoji_formatter(friend, 'RemarkName')
+ oldInfoDict = utils.search_dict_list(
+ fullList, 'UserName', friend['UserName'])
+ if oldInfoDict is None:
+ oldInfoDict = copy.deepcopy(friend)
+ if oldInfoDict['VerifyFlag'] & 8 == 0:
+ core.memberList.append(oldInfoDict)
+ else:
+ core.mpList.append(oldInfoDict)
+ else:
+ update_info_dict(oldInfoDict, friend)
+
+
+@contact_change
+def update_local_uin(core, msg):
+ '''
+ content contains uins and StatusNotifyUserName contains username
+ they are in same order, so what I do is to pair them together
+
+ I caught an exception in this method while not knowing why
+ but don't worry, it won't cause any problem
+ '''
+ uins = re.search('([^<]*?)<', msg['Content'])
+ usernameChangedList = []
+ r = {
+ 'Type': 'System',
+ 'Text': usernameChangedList,
+ 'SystemInfo': 'uins', }
+ if uins:
+ uins = uins.group(1).split(',')
+ usernames = msg['StatusNotifyUserName'].split(',')
+ if 0 < len(uins) == len(usernames):
+ for uin, username in zip(uins, usernames):
+ if not '@' in username:
+ continue
+ fullContact = core.memberList + core.chatroomList + core.mpList
+ userDicts = utils.search_dict_list(fullContact,
+ 'UserName', username)
+ if userDicts:
+ if userDicts.get('Uin', 0) == 0:
+ userDicts['Uin'] = uin
+ usernameChangedList.append(username)
+ logger.debug('Uin fetched: %s, %s' % (username, uin))
+ else:
+ if userDicts['Uin'] != uin:
+ logger.debug('Uin changed: %s, %s' % (
+ userDicts['Uin'], uin))
+ else:
+ if '@@' in username:
+ core.storageClass.updateLock.release()
+ update_chatroom(core, username)
+ core.storageClass.updateLock.acquire()
+ newChatroomDict = utils.search_dict_list(
+ core.chatroomList, 'UserName', username)
+ if newChatroomDict is None:
+ newChatroomDict = utils.struct_friend_info({
+ 'UserName': username,
+ 'Uin': uin,
+ 'Self': copy.deepcopy(core.loginInfo['User'])})
+ core.chatroomList.append(newChatroomDict)
+ else:
+ newChatroomDict['Uin'] = uin
+ elif '@' in username:
+ core.storageClass.updateLock.release()
+ update_friend(core, username)
+ core.storageClass.updateLock.acquire()
+ newFriendDict = utils.search_dict_list(
+ core.memberList, 'UserName', username)
+ if newFriendDict is None:
+ newFriendDict = utils.struct_friend_info({
+ 'UserName': username,
+ 'Uin': uin, })
+ core.memberList.append(newFriendDict)
+ else:
+ newFriendDict['Uin'] = uin
+ usernameChangedList.append(username)
+ logger.debug('Uin fetched: %s, %s' % (username, uin))
+ else:
+ logger.debug('Wrong length of uins & usernames: %s, %s' % (
+ len(uins), len(usernames)))
+ else:
+ logger.debug('No uins in 51 message')
+ logger.debug(msg['Content'])
+ return r
+
+
+def get_contact(self, update=False):
+ if not update:
+ return utils.contact_deep_copy(self, self.chatroomList)
+
+ def _get_contact(seq=0):
+ url = '%s/webwxgetcontact?r=%s&seq=%s&skey=%s' % (self.loginInfo['url'],
+ int(time.time()), seq, self.loginInfo['skey'])
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT, }
+ try:
+ r = self.s.get(url, headers=headers)
+ except:
+ logger.info(
+ 'Failed to fetch contact, that may because of the amount of your chatrooms')
+ for chatroom in self.get_chatrooms():
+ self.update_chatroom(chatroom['UserName'], detailedMember=True)
+ return 0, []
+ j = json.loads(r.content.decode('utf-8', 'replace'))
+ return j.get('Seq', 0), j.get('MemberList')
+ seq, memberList = 0, []
+ while 1:
+ seq, batchMemberList = _get_contact(seq)
+ memberList.extend(batchMemberList)
+ if seq == 0:
+ break
+ chatroomList, otherList = [], []
+ for m in memberList:
+ if m['Sex'] != 0:
+ otherList.append(m)
+ elif '@@' in m['UserName']:
+ chatroomList.append(m)
+ elif '@' in m['UserName']:
+ # mp will be dealt in update_local_friends as well
+ otherList.append(m)
+ if chatroomList:
+ update_local_chatrooms(self, chatroomList)
+ if otherList:
+ update_local_friends(self, otherList)
+ return utils.contact_deep_copy(self, chatroomList)
+
+
+def get_friends(self, update=False):
+ if update:
+ self.get_contact(update=True)
+ return utils.contact_deep_copy(self, self.memberList)
+
+
+def get_chatrooms(self, update=False, contactOnly=False):
+ if contactOnly:
+ return self.get_contact(update=True)
+ else:
+ if update:
+ self.get_contact(True)
+ return utils.contact_deep_copy(self, self.chatroomList)
+
+
+def get_mps(self, update=False):
+ if update:
+ self.get_contact(update=True)
+ return utils.contact_deep_copy(self, self.mpList)
+
+
+def set_alias(self, userName, alias):
+ oldFriendInfo = utils.search_dict_list(
+ self.memberList, 'UserName', userName)
+ if oldFriendInfo is None:
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1001, }})
+ url = '%s/webwxoplog?lang=%s&pass_ticket=%s' % (
+ self.loginInfo['url'], 'zh_CN', self.loginInfo['pass_ticket'])
+ data = {
+ 'UserName': userName,
+ 'CmdId': 2,
+ 'RemarkName': alias,
+ 'BaseRequest': self.loginInfo['BaseRequest'], }
+ headers = {'User-Agent': config.USER_AGENT}
+ r = self.s.post(url, json.dumps(data, ensure_ascii=False).encode('utf8'),
+ headers=headers)
+ r = ReturnValue(rawResponse=r)
+ if r:
+ oldFriendInfo['RemarkName'] = alias
+ return r
+
+
+def set_pinned(self, userName, isPinned=True):
+ url = '%s/webwxoplog?pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'UserName': userName,
+ 'CmdId': 3,
+ 'OP': int(isPinned),
+ 'BaseRequest': self.loginInfo['BaseRequest'], }
+ headers = {'User-Agent': config.USER_AGENT}
+ r = self.s.post(url, json=data, headers=headers)
+ return ReturnValue(rawResponse=r)
+
+
+def accept_friend(self, userName, v4='', autoUpdate=True):
+ url = f"{self.loginInfo['url']}/webwxverifyuser?r={int(time.time())}&pass_ticket={self.loginInfo['pass_ticket']}"
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Opcode': 3, # 3
+ 'VerifyUserListSize': 1,
+ 'VerifyUserList': [{
+ 'Value': userName,
+ 'VerifyUserTicket': v4, }],
+ 'VerifyContent': '',
+ 'SceneListCount': 1,
+ 'SceneList': [33],
+ 'skey': self.loginInfo['skey'], }
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT}
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8', 'replace'))
+ if autoUpdate:
+ self.update_friend(userName)
+ return ReturnValue(rawResponse=r)
+
+
+def get_head_img(self, userName=None, chatroomUserName=None, picDir=None):
+ ''' get head image
+ * if you want to get chatroom header: only set chatroomUserName
+ * if you want to get friend header: only set userName
+ * if you want to get chatroom member header: set both
+ '''
+ params = {
+ 'userName': userName or chatroomUserName or self.storageClass.userName,
+ 'skey': self.loginInfo['skey'],
+ 'type': 'big', }
+ url = '%s/webwxgeticon' % self.loginInfo['url']
+ if chatroomUserName is None:
+ infoDict = self.storageClass.search_friends(userName=userName)
+ if infoDict is None:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No friend found',
+ 'Ret': -1001, }})
+ else:
+ if userName is None:
+ url = '%s/webwxgetheadimg' % self.loginInfo['url']
+ else:
+ chatroom = self.storageClass.search_chatrooms(
+ userName=chatroomUserName)
+ if chatroomUserName is None:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No chatroom found',
+ 'Ret': -1001, }})
+ if 'EncryChatRoomId' in chatroom:
+ params['chatroomid'] = chatroom['EncryChatRoomId']
+ params['chatroomid'] = params.get(
+ 'chatroomid') or chatroom['UserName']
+ headers = {'User-Agent': config.USER_AGENT}
+ r = self.s.get(url, params=params, stream=True, headers=headers)
+ tempStorage = io.BytesIO()
+ for block in r.iter_content(1024):
+ tempStorage.write(block)
+ if picDir is None:
+ return tempStorage.getvalue()
+ with open(picDir, 'wb') as f:
+ f.write(tempStorage.getvalue())
+ tempStorage.seek(0)
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Successfully downloaded',
+ 'Ret': 0, },
+ 'PostFix': utils.get_image_postfix(tempStorage.read(20)), })
+
+
+def create_chatroom(self, memberList, topic=''):
+ url = '%s/webwxcreatechatroom?pass_ticket=%s&r=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'], int(time.time()))
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'MemberCount': len(memberList.split(',')),
+ 'MemberList': [{'UserName': member} for member in memberList.split(',')],
+ 'Topic': topic, }
+ headers = {
+ 'content-type': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT}
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8', 'ignore'))
+ return ReturnValue(rawResponse=r)
+
+
+def set_chatroom_name(self, chatroomUserName, name):
+ url = '%s/webwxupdatechatroom?fun=modtopic&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'ChatRoomName': chatroomUserName,
+ 'NewTopic': name, }
+ headers = {
+ 'content-type': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT}
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8', 'ignore'))
+ return ReturnValue(rawResponse=r)
+
+
+def delete_member_from_chatroom(self, chatroomUserName, memberList):
+ url = '%s/webwxupdatechatroom?fun=delmember&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'ChatRoomName': chatroomUserName,
+ 'DelMemberList': ','.join([member['UserName'] for member in memberList]), }
+ headers = {
+ 'content-type': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT}
+ r = self.s.post(url, data=json.dumps(data), headers=headers)
+ return ReturnValue(rawResponse=r)
+
+
+def add_member_into_chatroom(self, chatroomUserName, memberList,
+ useInvitation=False):
+ ''' add or invite member into chatroom
+ * there are two ways to get members into chatroom: invite or directly add
+ * but for chatrooms with more than 40 users, you can only use invite
+ * but don't worry we will auto-force userInvitation for you when necessary
+ '''
+ if not useInvitation:
+ chatroom = self.storageClass.search_chatrooms(
+ userName=chatroomUserName)
+ if not chatroom:
+ chatroom = self.update_chatroom(chatroomUserName)
+ if len(chatroom['MemberList']) > self.loginInfo['InviteStartCount']:
+ useInvitation = True
+ if useInvitation:
+ fun, memberKeyName = 'invitemember', 'InviteMemberList'
+ else:
+ fun, memberKeyName = 'addmember', 'AddMemberList'
+ url = '%s/webwxupdatechatroom?fun=%s&pass_ticket=%s' % (
+ self.loginInfo['url'], fun, self.loginInfo['pass_ticket'])
+ params = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'ChatRoomName': chatroomUserName,
+ memberKeyName: memberList, }
+ headers = {
+ 'content-type': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT}
+ r = self.s.post(url, data=json.dumps(params), headers=headers)
+ return ReturnValue(rawResponse=r)
diff --git a/lib/itchat/components/hotreload.py b/lib/itchat/components/hotreload.py
new file mode 100644
index 0000000..1003c67
--- /dev/null
+++ b/lib/itchat/components/hotreload.py
@@ -0,0 +1,102 @@
+import pickle, os
+import logging
+
+import requests
+
+from ..config import VERSION
+from ..returnvalues import ReturnValue
+from ..storage import templates
+from .contact import update_local_chatrooms, update_local_friends
+from .messages import produce_msg
+
+logger = logging.getLogger('itchat')
+
+def load_hotreload(core):
+ core.dump_login_status = dump_login_status
+ core.load_login_status = load_login_status
+
+def dump_login_status(self, fileDir=None):
+ fileDir = fileDir or self.hotReloadDir
+ try:
+ with open(fileDir, 'w') as f:
+ f.write('itchat - DELETE THIS')
+ os.remove(fileDir)
+ except:
+ raise Exception('Incorrect fileDir')
+ status = {
+ 'version' : VERSION,
+ 'loginInfo' : self.loginInfo,
+ 'cookies' : self.s.cookies.get_dict(),
+ 'storage' : self.storageClass.dumps()}
+ with open(fileDir, 'wb') as f:
+ pickle.dump(status, f)
+ logger.debug('Dump login status for hot reload successfully.')
+
+def load_login_status(self, fileDir,
+ loginCallback=None, exitCallback=None):
+ try:
+ with open(fileDir, 'rb') as f:
+ j = pickle.load(f)
+ except Exception as e:
+ logger.debug('No such file, loading login status failed.')
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No such file, loading login status failed.',
+ 'Ret': -1002, }})
+
+ if j.get('version', '') != VERSION:
+ logger.debug(('you have updated itchat from %s to %s, ' +
+ 'so cached status is ignored') % (
+ j.get('version', 'old version'), VERSION))
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'cached status ignored because of version',
+ 'Ret': -1005, }})
+ self.loginInfo = j['loginInfo']
+ self.loginInfo['User'] = templates.User(self.loginInfo['User'])
+ self.loginInfo['User'].core = self
+ self.s.cookies = requests.utils.cookiejar_from_dict(j['cookies'])
+ self.storageClass.loads(j['storage'])
+ try:
+ msgList, contactList = self.get_msg()
+ except:
+ msgList = contactList = None
+ if (msgList or contactList) is None:
+ self.logout()
+ load_last_login_status(self.s, j['cookies'])
+ logger.debug('server refused, loading login status failed.')
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'server refused, loading login status failed.',
+ 'Ret': -1003, }})
+ else:
+ if contactList:
+ for contact in contactList:
+ if '@@' in contact['UserName']:
+ update_local_chatrooms(self, [contact])
+ else:
+ update_local_friends(self, [contact])
+ if msgList:
+ msgList = produce_msg(self, msgList)
+ for msg in msgList: self.msgList.put(msg)
+ self.start_receiving(exitCallback)
+ logger.debug('loading login status succeeded.')
+ if hasattr(loginCallback, '__call__'):
+ loginCallback()
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'loading login status succeeded.',
+ 'Ret': 0, }})
+
+def load_last_login_status(session, cookiesDict):
+ try:
+ session.cookies = requests.utils.cookiejar_from_dict({
+ 'webwxuvid': cookiesDict['webwxuvid'],
+ 'webwx_auth_ticket': cookiesDict['webwx_auth_ticket'],
+ 'login_frequency': '2',
+ 'last_wxuin': cookiesDict['wxuin'],
+ 'wxloadtime': cookiesDict['wxloadtime'] + '_expired',
+ 'wxpluginkey': cookiesDict['wxloadtime'],
+ 'wxuin': cookiesDict['wxuin'],
+ 'mm_lang': 'zh_CN',
+ 'MM_WX_NOTIFY_STATE': '1',
+ 'MM_WX_SOUND_STATE': '1', })
+ except:
+ logger.info('Load status for push login failed, we may have experienced a cookies change.')
+ logger.info('If you are using the newest version of itchat, you may report a bug.')
diff --git a/lib/itchat/components/login.py b/lib/itchat/components/login.py
new file mode 100644
index 0000000..a2dd17c
--- /dev/null
+++ b/lib/itchat/components/login.py
@@ -0,0 +1,418 @@
+import os
+import time
+import re
+import io
+import threading
+import json
+import xml.dom.minidom
+import random
+import traceback
+import logging
+try:
+ from httplib import BadStatusLine
+except ImportError:
+ from http.client import BadStatusLine
+
+import requests
+from pyqrcode import QRCode
+
+from .. import config, utils
+from ..returnvalues import ReturnValue
+from ..storage.templates import wrap_user_dict
+from .contact import update_local_chatrooms, update_local_friends
+from .messages import produce_msg
+
+logger = logging.getLogger('itchat')
+
+
+def load_login(core):
+ core.login = login
+ core.get_QRuuid = get_QRuuid
+ core.get_QR = get_QR
+ core.check_login = check_login
+ core.web_init = web_init
+ core.show_mobile_login = show_mobile_login
+ core.start_receiving = start_receiving
+ core.get_msg = get_msg
+ core.logout = logout
+
+
+def login(self, enableCmdQR=False, picDir=None, qrCallback=None,
+ loginCallback=None, exitCallback=None):
+ if self.alive or self.isLogging:
+ logger.warning('itchat has already logged in.')
+ return
+ self.isLogging = True
+ logger.info('Ready to login.')
+ while self.isLogging:
+ uuid = push_login(self)
+ if uuid:
+ qrStorage = io.BytesIO()
+ else:
+ logger.info('Getting uuid of QR code.')
+ while not self.get_QRuuid():
+ time.sleep(1)
+ logger.info('Downloading QR code.')
+ qrStorage = self.get_QR(enableCmdQR=enableCmdQR,
+ picDir=picDir, qrCallback=qrCallback)
+ # logger.info('Please scan the QR code to log in.')
+ isLoggedIn = False
+ while not isLoggedIn:
+ status = self.check_login()
+ if hasattr(qrCallback, '__call__'):
+ qrCallback(uuid=self.uuid, status=status,
+ qrcode=qrStorage.getvalue())
+ if status == '200':
+ isLoggedIn = True
+ elif status == '201':
+ if isLoggedIn is not None:
+ logger.info('Please press confirm on your phone.')
+ isLoggedIn = None
+ time.sleep(7)
+ time.sleep(0.5)
+ elif status != '408':
+ break
+ if isLoggedIn:
+ break
+ elif self.isLogging:
+ logger.info('Log in time out, reloading QR code.')
+ else:
+ return # log in process is stopped by user
+ logger.info('Loading the contact, this may take a little while.')
+ self.web_init()
+ self.show_mobile_login()
+ self.get_contact(True)
+ if hasattr(loginCallback, '__call__'):
+ r = loginCallback()
+ else:
+ # utils.clear_screen()
+ if os.path.exists(picDir or config.DEFAULT_QR):
+ os.remove(picDir or config.DEFAULT_QR)
+ logger.info('Login successfully as %s' % self.storageClass.nickName)
+ self.start_receiving(exitCallback)
+ self.isLogging = False
+
+
+def push_login(core):
+ cookiesDict = core.s.cookies.get_dict()
+ if 'wxuin' in cookiesDict:
+ url = '%s/cgi-bin/mmwebwx-bin/webwxpushloginurl?uin=%s' % (
+ config.BASE_URL, cookiesDict['wxuin'])
+ headers = {'User-Agent': config.USER_AGENT}
+ r = core.s.get(url, headers=headers).json()
+ if 'uuid' in r and r.get('ret') in (0, '0'):
+ core.uuid = r['uuid']
+ return r['uuid']
+ return False
+
+
+def get_QRuuid(self):
+ url = '%s/jslogin' % config.BASE_URL
+ params = {
+ 'appid': 'wx782c26e4c19acffb',
+ 'fun': 'new',
+ 'redirect_uri': 'https://wx.qq.com/cgi-bin/mmwebwx-bin/webwxnewloginpage?mod=desktop',
+ 'lang': 'zh_CN'}
+ headers = {'User-Agent': config.USER_AGENT}
+ r = self.s.get(url, params=params, headers=headers)
+ regx = r'window.QRLogin.code = (\d+); window.QRLogin.uuid = "(\S+?)";'
+ data = re.search(regx, r.text)
+ if data and data.group(1) == '200':
+ self.uuid = data.group(2)
+ return self.uuid
+
+
+def get_QR(self, uuid=None, enableCmdQR=False, picDir=None, qrCallback=None):
+ uuid = uuid or self.uuid
+ picDir = picDir or config.DEFAULT_QR
+ qrStorage = io.BytesIO()
+ qrCode = QRCode('https://login.weixin.qq.com/l/' + uuid)
+ qrCode.png(qrStorage, scale=10)
+ if hasattr(qrCallback, '__call__'):
+ qrCallback(uuid=uuid, status='0', qrcode=qrStorage.getvalue())
+ else:
+ with open(picDir, 'wb') as f:
+ f.write(qrStorage.getvalue())
+ if enableCmdQR:
+ utils.print_cmd_qr(qrCode.text(1), enableCmdQR=enableCmdQR)
+ else:
+ utils.print_qr(picDir)
+ return qrStorage
+
+
+def check_login(self, uuid=None):
+ uuid = uuid or self.uuid
+ url = '%s/cgi-bin/mmwebwx-bin/login' % config.BASE_URL
+ localTime = int(time.time())
+ params = 'loginicon=true&uuid=%s&tip=1&r=%s&_=%s' % (
+ uuid, int(-localTime / 1579), localTime)
+ headers = {'User-Agent': config.USER_AGENT}
+ r = self.s.get(url, params=params, headers=headers)
+ regx = r'window.code=(\d+)'
+ data = re.search(regx, r.text)
+ if data and data.group(1) == '200':
+ if process_login_info(self, r.text):
+ return '200'
+ else:
+ return '400'
+ elif data:
+ return data.group(1)
+ else:
+ return '400'
+
+
+def process_login_info(core, loginContent):
+ ''' when finish login (scanning qrcode)
+ * syncUrl and fileUploadingUrl will be fetched
+ * deviceid and msgid will be generated
+ * skey, wxsid, wxuin, pass_ticket will be fetched
+ '''
+ regx = r'window.redirect_uri="(\S+)";'
+ core.loginInfo['url'] = re.search(regx, loginContent).group(1)
+ headers = {'User-Agent': config.USER_AGENT,
+ 'client-version': config.UOS_PATCH_CLIENT_VERSION,
+ 'extspam': config.UOS_PATCH_EXTSPAM,
+ 'referer': 'https://wx.qq.com/?&lang=zh_CN&target=t'
+ }
+ r = core.s.get(core.loginInfo['url'],
+ headers=headers, allow_redirects=False)
+ core.loginInfo['url'] = core.loginInfo['url'][:core.loginInfo['url'].rfind(
+ '/')]
+ for indexUrl, detailedUrl in (
+ ("wx2.qq.com", ("file.wx2.qq.com", "webpush.wx2.qq.com")),
+ ("wx8.qq.com", ("file.wx8.qq.com", "webpush.wx8.qq.com")),
+ ("qq.com", ("file.wx.qq.com", "webpush.wx.qq.com")),
+ ("web2.wechat.com", ("file.web2.wechat.com", "webpush.web2.wechat.com")),
+ ("wechat.com", ("file.web.wechat.com", "webpush.web.wechat.com"))):
+ fileUrl, syncUrl = ['https://%s/cgi-bin/mmwebwx-bin' %
+ url for url in detailedUrl]
+ if indexUrl in core.loginInfo['url']:
+ core.loginInfo['fileUrl'], core.loginInfo['syncUrl'] = \
+ fileUrl, syncUrl
+ break
+ else:
+ core.loginInfo['fileUrl'] = core.loginInfo['syncUrl'] = core.loginInfo['url']
+ core.loginInfo['deviceid'] = 'e' + repr(random.random())[2:17]
+ core.loginInfo['logintime'] = int(time.time() * 1e3)
+ core.loginInfo['BaseRequest'] = {}
+ cookies = core.s.cookies.get_dict()
+ res = re.findall('(.*?)', r.text, re.S)
+ skey = res[0] if res else None
+ res = re.findall(
+ '(.*?)', r.text, re.S)
+ pass_ticket = res[0] if res else None
+ if skey is not None:
+ core.loginInfo['skey'] = core.loginInfo['BaseRequest']['Skey'] = skey
+ core.loginInfo['wxsid'] = core.loginInfo['BaseRequest']['Sid'] = cookies["wxsid"]
+ core.loginInfo['wxuin'] = core.loginInfo['BaseRequest']['Uin'] = cookies["wxuin"]
+ if pass_ticket is not None:
+ core.loginInfo['pass_ticket'] = pass_ticket
+ # A question : why pass_ticket == DeviceID ?
+ # deviceID is only a randomly generated number
+
+ # UOS PATCH By luvletter2333, Sun Feb 28 10:00 PM
+ # for node in xml.dom.minidom.parseString(r.text).documentElement.childNodes:
+ # if node.nodeName == 'skey':
+ # core.loginInfo['skey'] = core.loginInfo['BaseRequest']['Skey'] = node.childNodes[0].data
+ # elif node.nodeName == 'wxsid':
+ # core.loginInfo['wxsid'] = core.loginInfo['BaseRequest']['Sid'] = node.childNodes[0].data
+ # elif node.nodeName == 'wxuin':
+ # core.loginInfo['wxuin'] = core.loginInfo['BaseRequest']['Uin'] = node.childNodes[0].data
+ # elif node.nodeName == 'pass_ticket':
+ # core.loginInfo['pass_ticket'] = core.loginInfo['BaseRequest']['DeviceID'] = node.childNodes[0].data
+ if not all([key in core.loginInfo for key in ('skey', 'wxsid', 'wxuin', 'pass_ticket')]):
+ logger.error(
+ 'Your wechat account may be LIMITED to log in WEB wechat, error info:\n%s' % r.text)
+ core.isLogging = False
+ return False
+ return True
+
+
+def web_init(self):
+ url = '%s/webwxinit' % self.loginInfo['url']
+ params = {
+ 'r': int(-time.time() / 1579),
+ 'pass_ticket': self.loginInfo['pass_ticket'], }
+ data = {'BaseRequest': self.loginInfo['BaseRequest'], }
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT, }
+ r = self.s.post(url, params=params, data=json.dumps(data), headers=headers)
+ dic = json.loads(r.content.decode('utf-8', 'replace'))
+ # deal with login info
+ utils.emoji_formatter(dic['User'], 'NickName')
+ self.loginInfo['InviteStartCount'] = int(dic['InviteStartCount'])
+ self.loginInfo['User'] = wrap_user_dict(
+ utils.struct_friend_info(dic['User']))
+ self.memberList.append(self.loginInfo['User'])
+ self.loginInfo['SyncKey'] = dic['SyncKey']
+ self.loginInfo['synckey'] = '|'.join(['%s_%s' % (item['Key'], item['Val'])
+ for item in dic['SyncKey']['List']])
+ self.storageClass.userName = dic['User']['UserName']
+ self.storageClass.nickName = dic['User']['NickName']
+ # deal with contact list returned when init
+ contactList = dic.get('ContactList', [])
+ chatroomList, otherList = [], []
+ for m in contactList:
+ if m['Sex'] != 0:
+ otherList.append(m)
+ elif '@@' in m['UserName']:
+ m['MemberList'] = [] # don't let dirty info pollute the list
+ chatroomList.append(m)
+ elif '@' in m['UserName']:
+ # mp will be dealt in update_local_friends as well
+ otherList.append(m)
+ if chatroomList:
+ update_local_chatrooms(self, chatroomList)
+ if otherList:
+ update_local_friends(self, otherList)
+ return dic
+
+
+def show_mobile_login(self):
+ url = '%s/webwxstatusnotify?lang=zh_CN&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Code': 3,
+ 'FromUserName': self.storageClass.userName,
+ 'ToUserName': self.storageClass.userName,
+ 'ClientMsgId': int(time.time()), }
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT, }
+ r = self.s.post(url, data=json.dumps(data), headers=headers)
+ return ReturnValue(rawResponse=r)
+
+
+def start_receiving(self, exitCallback=None, getReceivingFnOnly=False):
+ self.alive = True
+
+ def maintain_loop():
+ retryCount = 0
+ while self.alive:
+ try:
+ i = sync_check(self)
+ if i is None:
+ self.alive = False
+ elif i == '0':
+ pass
+ else:
+ msgList, contactList = self.get_msg()
+ if msgList:
+ msgList = produce_msg(self, msgList)
+ for msg in msgList:
+ self.msgList.put(msg)
+ if contactList:
+ chatroomList, otherList = [], []
+ for contact in contactList:
+ if '@@' in contact['UserName']:
+ chatroomList.append(contact)
+ else:
+ otherList.append(contact)
+ chatroomMsg = update_local_chatrooms(
+ self, chatroomList)
+ chatroomMsg['User'] = self.loginInfo['User']
+ self.msgList.put(chatroomMsg)
+ update_local_friends(self, otherList)
+ retryCount = 0
+ except requests.exceptions.ReadTimeout:
+ pass
+ except:
+ retryCount += 1
+ logger.error(traceback.format_exc())
+ if self.receivingRetryCount < retryCount:
+ logger.error("Having tried %s times, but still failed. " % (
+ retryCount) + "Stop trying...")
+ self.alive = False
+ else:
+ time.sleep(1)
+ self.logout()
+ if hasattr(exitCallback, '__call__'):
+ exitCallback()
+ else:
+ logger.info('LOG OUT!')
+ if getReceivingFnOnly:
+ return maintain_loop
+ else:
+ maintainThread = threading.Thread(target=maintain_loop)
+ maintainThread.setDaemon(True)
+ maintainThread.start()
+
+
+def sync_check(self):
+ url = '%s/synccheck' % self.loginInfo.get('syncUrl', self.loginInfo['url'])
+ params = {
+ 'r': int(time.time() * 1000),
+ 'skey': self.loginInfo['skey'],
+ 'sid': self.loginInfo['wxsid'],
+ 'uin': self.loginInfo['wxuin'],
+ 'deviceid': self.loginInfo['deviceid'],
+ 'synckey': self.loginInfo['synckey'],
+ '_': self.loginInfo['logintime'], }
+ headers = {'User-Agent': config.USER_AGENT}
+ self.loginInfo['logintime'] += 1
+ try:
+ r = self.s.get(url, params=params, headers=headers,
+ timeout=config.TIMEOUT)
+ except requests.exceptions.ConnectionError as e:
+ try:
+ if not isinstance(e.args[0].args[1], BadStatusLine):
+ raise
+ # will return a package with status '0 -'
+ # and value like:
+ # 6f:00:8a:9c:09:74:e4:d8:e0:14:bf:96:3a:56:a0:64:1b:a4:25:5d:12:f4:31:a5:30:f1:c6:48:5f:c3:75:6a:99:93
+ # seems like status of typing, but before I make further achievement code will remain like this
+ return '2'
+ except:
+ raise
+ r.raise_for_status()
+ regx = r'window.synccheck={retcode:"(\d+)",selector:"(\d+)"}'
+ pm = re.search(regx, r.text)
+ if pm is None or pm.group(1) != '0':
+ logger.error('Unexpected sync check result: %s' % r.text)
+ return None
+ return pm.group(2)
+
+
+def get_msg(self):
+ self.loginInfo['deviceid'] = 'e' + repr(random.random())[2:17]
+ url = '%s/webwxsync?sid=%s&skey=%s&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['wxsid'],
+ self.loginInfo['skey'], self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'SyncKey': self.loginInfo['SyncKey'],
+ 'rr': ~int(time.time()), }
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent': config.USER_AGENT}
+ r = self.s.post(url, data=json.dumps(data),
+ headers=headers, timeout=config.TIMEOUT)
+ dic = json.loads(r.content.decode('utf-8', 'replace'))
+ if dic['BaseResponse']['Ret'] != 0:
+ return None, None
+ self.loginInfo['SyncKey'] = dic['SyncKey']
+ self.loginInfo['synckey'] = '|'.join(['%s_%s' % (item['Key'], item['Val'])
+ for item in dic['SyncCheckKey']['List']])
+ return dic['AddMsgList'], dic['ModContactList']
+
+
+def logout(self):
+ if self.alive:
+ url = '%s/webwxlogout' % self.loginInfo['url']
+ params = {
+ 'redirect': 1,
+ 'type': 1,
+ 'skey': self.loginInfo['skey'], }
+ headers = {'User-Agent': config.USER_AGENT}
+ self.s.get(url, params=params, headers=headers)
+ self.alive = False
+ self.isLogging = False
+ self.s.cookies.clear()
+ del self.chatroomList[:]
+ del self.memberList[:]
+ del self.mpList[:]
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'logout successfully.',
+ 'Ret': 0, }})
diff --git a/lib/itchat/components/messages.py b/lib/itchat/components/messages.py
new file mode 100644
index 0000000..85c0ca2
--- /dev/null
+++ b/lib/itchat/components/messages.py
@@ -0,0 +1,528 @@
+import os, time, re, io
+import json
+import mimetypes, hashlib
+import logging
+from collections import OrderedDict
+
+import requests
+
+from .. import config, utils
+from ..returnvalues import ReturnValue
+from ..storage import templates
+from .contact import update_local_uin
+
+logger = logging.getLogger('itchat')
+
+def load_messages(core):
+ core.send_raw_msg = send_raw_msg
+ core.send_msg = send_msg
+ core.upload_file = upload_file
+ core.send_file = send_file
+ core.send_image = send_image
+ core.send_video = send_video
+ core.send = send
+ core.revoke = revoke
+
+def get_download_fn(core, url, msgId):
+ def download_fn(downloadDir=None):
+ params = {
+ 'msgid': msgId,
+ 'skey': core.loginInfo['skey'],}
+ headers = { 'User-Agent' : config.USER_AGENT }
+ r = core.s.get(url, params=params, stream=True, headers = headers)
+ tempStorage = io.BytesIO()
+ for block in r.iter_content(1024):
+ tempStorage.write(block)
+ if downloadDir is None:
+ return tempStorage.getvalue()
+ with open(downloadDir, 'wb') as f:
+ f.write(tempStorage.getvalue())
+ tempStorage.seek(0)
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Successfully downloaded',
+ 'Ret': 0, },
+ 'PostFix': utils.get_image_postfix(tempStorage.read(20)), })
+ return download_fn
+
+def produce_msg(core, msgList):
+ ''' for messages types
+ * 40 msg, 43 videochat, 50 VOIPMSG, 52 voipnotifymsg
+ * 53 webwxvoipnotifymsg, 9999 sysnotice
+ '''
+ rl = []
+ srl = [40, 43, 50, 52, 53, 9999]
+ for m in msgList:
+ # get actual opposite
+ if m['FromUserName'] == core.storageClass.userName:
+ actualOpposite = m['ToUserName']
+ else:
+ actualOpposite = m['FromUserName']
+ # produce basic message
+ if '@@' in m['FromUserName'] or '@@' in m['ToUserName']:
+ produce_group_chat(core, m)
+ else:
+ utils.msg_formatter(m, 'Content')
+ # set user of msg
+ if '@@' in actualOpposite:
+ m['User'] = core.search_chatrooms(userName=actualOpposite) or \
+ templates.Chatroom({'UserName': actualOpposite})
+ # we don't need to update chatroom here because we have
+ # updated once when producing basic message
+ elif actualOpposite in ('filehelper', 'fmessage'):
+ m['User'] = templates.User({'UserName': actualOpposite})
+ else:
+ m['User'] = core.search_mps(userName=actualOpposite) or \
+ core.search_friends(userName=actualOpposite) or \
+ templates.User(userName=actualOpposite)
+ # by default we think there may be a user missing not a mp
+ m['User'].core = core
+ if m['MsgType'] == 1: # words
+ if m['Url']:
+ regx = r'(.+?\(.+?\))'
+ data = re.search(regx, m['Content'])
+ data = 'Map' if data is None else data.group(1)
+ msg = {
+ 'Type': 'Map',
+ 'Text': data,}
+ else:
+ msg = {
+ 'Type': 'Text',
+ 'Text': m['Content'],}
+ elif m['MsgType'] == 3 or m['MsgType'] == 47: # picture
+ download_fn = get_download_fn(core,
+ '%s/webwxgetmsgimg' % core.loginInfo['url'], m['NewMsgId'])
+ msg = {
+ 'Type' : 'Picture',
+ 'FileName' : '%s.%s' % (time.strftime('%y%m%d-%H%M%S', time.localtime()),
+ 'png' if m['MsgType'] == 3 else 'gif'),
+ 'Text' : download_fn, }
+ elif m['MsgType'] == 34: # voice
+ download_fn = get_download_fn(core,
+ '%s/webwxgetvoice' % core.loginInfo['url'], m['NewMsgId'])
+ msg = {
+ 'Type': 'Recording',
+ 'FileName' : '%s.mp3' % time.strftime('%y%m%d-%H%M%S', time.localtime()),
+ 'Text': download_fn,}
+ elif m['MsgType'] == 37: # friends
+ m['User']['UserName'] = m['RecommendInfo']['UserName']
+ msg = {
+ 'Type': 'Friends',
+ 'Text': {
+ 'status' : m['Status'],
+ 'userName' : m['RecommendInfo']['UserName'],
+ 'verifyContent' : m['Ticket'],
+ 'autoUpdate' : m['RecommendInfo'], }, }
+ m['User'].verifyDict = msg['Text']
+ elif m['MsgType'] == 42: # name card
+ msg = {
+ 'Type': 'Card',
+ 'Text': m['RecommendInfo'], }
+ elif m['MsgType'] in (43, 62): # tiny video
+ msgId = m['MsgId']
+ def download_video(videoDir=None):
+ url = '%s/webwxgetvideo' % core.loginInfo['url']
+ params = {
+ 'msgid': msgId,
+ 'skey': core.loginInfo['skey'],}
+ headers = {'Range': 'bytes=0-', 'User-Agent' : config.USER_AGENT }
+ r = core.s.get(url, params=params, headers=headers, stream=True)
+ tempStorage = io.BytesIO()
+ for block in r.iter_content(1024):
+ tempStorage.write(block)
+ if videoDir is None:
+ return tempStorage.getvalue()
+ with open(videoDir, 'wb') as f:
+ f.write(tempStorage.getvalue())
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Successfully downloaded',
+ 'Ret': 0, }})
+ msg = {
+ 'Type': 'Video',
+ 'FileName' : '%s.mp4' % time.strftime('%y%m%d-%H%M%S', time.localtime()),
+ 'Text': download_video, }
+ elif m['MsgType'] == 49: # sharing
+ if m['AppMsgType'] == 0: # chat history
+ msg = {
+ 'Type': 'Note',
+ 'Text': m['Content'], }
+ elif m['AppMsgType'] == 6:
+ rawMsg = m
+ cookiesList = {name:data for name,data in core.s.cookies.items()}
+ def download_atta(attaDir=None):
+ url = core.loginInfo['fileUrl'] + '/webwxgetmedia'
+ params = {
+ 'sender': rawMsg['FromUserName'],
+ 'mediaid': rawMsg['MediaId'],
+ 'filename': rawMsg['FileName'],
+ 'fromuser': core.loginInfo['wxuin'],
+ 'pass_ticket': 'undefined',
+ 'webwx_data_ticket': cookiesList['webwx_data_ticket'],}
+ headers = { 'User-Agent' : config.USER_AGENT }
+ r = core.s.get(url, params=params, stream=True, headers=headers)
+ tempStorage = io.BytesIO()
+ for block in r.iter_content(1024):
+ tempStorage.write(block)
+ if attaDir is None:
+ return tempStorage.getvalue()
+ with open(attaDir, 'wb') as f:
+ f.write(tempStorage.getvalue())
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Successfully downloaded',
+ 'Ret': 0, }})
+ msg = {
+ 'Type': 'Attachment',
+ 'Text': download_atta, }
+ elif m['AppMsgType'] == 8:
+ download_fn = get_download_fn(core,
+ '%s/webwxgetmsgimg' % core.loginInfo['url'], m['NewMsgId'])
+ msg = {
+ 'Type' : 'Picture',
+ 'FileName' : '%s.gif' % (
+ time.strftime('%y%m%d-%H%M%S', time.localtime())),
+ 'Text' : download_fn, }
+ elif m['AppMsgType'] == 17:
+ msg = {
+ 'Type': 'Note',
+ 'Text': m['FileName'], }
+ elif m['AppMsgType'] == 2000:
+ regx = r'\[CDATA\[(.+?)\][\s\S]+?\[CDATA\[(.+?)\]'
+ data = re.search(regx, m['Content'])
+ if data:
+ data = data.group(2).split(u'\u3002')[0]
+ else:
+ data = 'You may found detailed info in Content key.'
+ msg = {
+ 'Type': 'Note',
+ 'Text': data, }
+ else:
+ msg = {
+ 'Type': 'Sharing',
+ 'Text': m['FileName'], }
+ elif m['MsgType'] == 51: # phone init
+ msg = update_local_uin(core, m)
+ elif m['MsgType'] == 10000:
+ msg = {
+ 'Type': 'Note',
+ 'Text': m['Content'],}
+ elif m['MsgType'] == 10002:
+ regx = r'\[CDATA\[(.+?)\]\]'
+ data = re.search(regx, m['Content'])
+ data = 'System message' if data is None else data.group(1).replace('\\', '')
+ msg = {
+ 'Type': 'Note',
+ 'Text': data, }
+ elif m['MsgType'] in srl:
+ msg = {
+ 'Type': 'Useless',
+ 'Text': 'UselessMsg', }
+ else:
+ logger.debug('Useless message received: %s\n%s' % (m['MsgType'], str(m)))
+ msg = {
+ 'Type': 'Useless',
+ 'Text': 'UselessMsg', }
+ m = dict(m, **msg)
+ rl.append(m)
+ return rl
+
+def produce_group_chat(core, msg):
+ r = re.match('(@[0-9a-z]*?):
(.*)$', msg['Content'])
+ if r:
+ actualUserName, content = r.groups()
+ chatroomUserName = msg['FromUserName']
+ elif msg['FromUserName'] == core.storageClass.userName:
+ actualUserName = core.storageClass.userName
+ content = msg['Content']
+ chatroomUserName = msg['ToUserName']
+ else:
+ msg['ActualUserName'] = core.storageClass.userName
+ msg['ActualNickName'] = core.storageClass.nickName
+ msg['IsAt'] = False
+ utils.msg_formatter(msg, 'Content')
+ return
+ chatroom = core.storageClass.search_chatrooms(userName=chatroomUserName)
+ member = utils.search_dict_list((chatroom or {}).get(
+ 'MemberList') or [], 'UserName', actualUserName)
+ if member is None:
+ chatroom = core.update_chatroom(chatroomUserName)
+ member = utils.search_dict_list((chatroom or {}).get(
+ 'MemberList') or [], 'UserName', actualUserName)
+ if member is None:
+ logger.debug('chatroom member fetch failed with %s' % actualUserName)
+ msg['ActualNickName'] = ''
+ msg['IsAt'] = False
+ else:
+ msg['ActualNickName'] = member.get('DisplayName', '') or member['NickName']
+ atFlag = '@' + (chatroom['Self'].get('DisplayName', '') or core.storageClass.nickName)
+ msg['IsAt'] = (
+ (atFlag + (u'\u2005' if u'\u2005' in msg['Content'] else ' '))
+ in msg['Content'] or msg['Content'].endswith(atFlag))
+ msg['ActualUserName'] = actualUserName
+ msg['Content'] = content
+ utils.msg_formatter(msg, 'Content')
+
+def send_raw_msg(self, msgType, content, toUserName):
+ url = '%s/webwxsendmsg' % self.loginInfo['url']
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Msg': {
+ 'Type': msgType,
+ 'Content': content,
+ 'FromUserName': self.storageClass.userName,
+ 'ToUserName': (toUserName if toUserName else self.storageClass.userName),
+ 'LocalID': int(time.time() * 1e4),
+ 'ClientMsgId': int(time.time() * 1e4),
+ },
+ 'Scene': 0, }
+ headers = { 'ContentType': 'application/json; charset=UTF-8', 'User-Agent' : config.USER_AGENT }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
+
+def send_msg(self, msg='Test Message', toUserName=None):
+ logger.debug('Request to send a text message to %s: %s' % (toUserName, msg))
+ r = self.send_raw_msg(1, msg, toUserName)
+ return r
+
+def _prepare_file(fileDir, file_=None):
+ fileDict = {}
+ if file_:
+ if hasattr(file_, 'read'):
+ file_ = file_.read()
+ else:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'file_ param should be opened file',
+ 'Ret': -1005, }})
+ else:
+ if not utils.check_file(fileDir):
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No file found in specific dir',
+ 'Ret': -1002, }})
+ with open(fileDir, 'rb') as f:
+ file_ = f.read()
+ fileDict['fileSize'] = len(file_)
+ fileDict['fileMd5'] = hashlib.md5(file_).hexdigest()
+ fileDict['file_'] = io.BytesIO(file_)
+ return fileDict
+
+def upload_file(self, fileDir, isPicture=False, isVideo=False,
+ toUserName='filehelper', file_=None, preparedFile=None):
+ logger.debug('Request to upload a %s: %s' % (
+ 'picture' if isPicture else 'video' if isVideo else 'file', fileDir))
+ if not preparedFile:
+ preparedFile = _prepare_file(fileDir, file_)
+ if not preparedFile:
+ return preparedFile
+ fileSize, fileMd5, file_ = \
+ preparedFile['fileSize'], preparedFile['fileMd5'], preparedFile['file_']
+ fileSymbol = 'pic' if isPicture else 'video' if isVideo else'doc'
+ chunks = int((fileSize - 1) / 524288) + 1
+ clientMediaId = int(time.time() * 1e4)
+ uploadMediaRequest = json.dumps(OrderedDict([
+ ('UploadType', 2),
+ ('BaseRequest', self.loginInfo['BaseRequest']),
+ ('ClientMediaId', clientMediaId),
+ ('TotalLen', fileSize),
+ ('StartPos', 0),
+ ('DataLen', fileSize),
+ ('MediaType', 4),
+ ('FromUserName', self.storageClass.userName),
+ ('ToUserName', toUserName),
+ ('FileMd5', fileMd5)]
+ ), separators = (',', ':'))
+ r = {'BaseResponse': {'Ret': -1005, 'ErrMsg': 'Empty file detected'}}
+ for chunk in range(chunks):
+ r = upload_chunk_file(self, fileDir, fileSymbol, fileSize,
+ file_, chunk, chunks, uploadMediaRequest)
+ file_.close()
+ if isinstance(r, dict):
+ return ReturnValue(r)
+ return ReturnValue(rawResponse=r)
+
+def upload_chunk_file(core, fileDir, fileSymbol, fileSize,
+ file_, chunk, chunks, uploadMediaRequest):
+ url = core.loginInfo.get('fileUrl', core.loginInfo['url']) + \
+ '/webwxuploadmedia?f=json'
+ # save it on server
+ cookiesList = {name:data for name,data in core.s.cookies.items()}
+ fileType = mimetypes.guess_type(fileDir)[0] or 'application/octet-stream'
+ fileName = utils.quote(os.path.basename(fileDir))
+ files = OrderedDict([
+ ('id', (None, 'WU_FILE_0')),
+ ('name', (None, fileName)),
+ ('type', (None, fileType)),
+ ('lastModifiedDate', (None, time.strftime('%a %b %d %Y %H:%M:%S GMT+0800 (CST)'))),
+ ('size', (None, str(fileSize))),
+ ('chunks', (None, None)),
+ ('chunk', (None, None)),
+ ('mediatype', (None, fileSymbol)),
+ ('uploadmediarequest', (None, uploadMediaRequest)),
+ ('webwx_data_ticket', (None, cookiesList['webwx_data_ticket'])),
+ ('pass_ticket', (None, core.loginInfo['pass_ticket'])),
+ ('filename' , (fileName, file_.read(524288), 'application/octet-stream'))])
+ if chunks == 1:
+ del files['chunk']; del files['chunks']
+ else:
+ files['chunk'], files['chunks'] = (None, str(chunk)), (None, str(chunks))
+ headers = { 'User-Agent' : config.USER_AGENT }
+ return core.s.post(url, files=files, headers=headers, timeout=config.TIMEOUT)
+
+def send_file(self, fileDir, toUserName=None, mediaId=None, file_=None):
+ logger.debug('Request to send a file(mediaId: %s) to %s: %s' % (
+ mediaId, toUserName, fileDir))
+ if hasattr(fileDir, 'read'):
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'fileDir param should not be an opened file in send_file',
+ 'Ret': -1005, }})
+ if toUserName is None:
+ toUserName = self.storageClass.userName
+ preparedFile = _prepare_file(fileDir, file_)
+ if not preparedFile:
+ return preparedFile
+ fileSize = preparedFile['fileSize']
+ if mediaId is None:
+ r = self.upload_file(fileDir, preparedFile=preparedFile)
+ if r:
+ mediaId = r['MediaId']
+ else:
+ return r
+ url = '%s/webwxsendappmsg?fun=async&f=json' % self.loginInfo['url']
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Msg': {
+ 'Type': 6,
+ 'Content': ("%s" % os.path.basename(fileDir) +
+ "6" +
+ "%s%s" % (str(fileSize), mediaId) +
+ "%s" % os.path.splitext(fileDir)[1].replace('.','')),
+ 'FromUserName': self.storageClass.userName,
+ 'ToUserName': toUserName,
+ 'LocalID': int(time.time() * 1e4),
+ 'ClientMsgId': int(time.time() * 1e4), },
+ 'Scene': 0, }
+ headers = {
+ 'User-Agent': config.USER_AGENT,
+ 'Content-Type': 'application/json;charset=UTF-8', }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
+
+def send_image(self, fileDir=None, toUserName=None, mediaId=None, file_=None):
+ logger.debug('Request to send a image(mediaId: %s) to %s: %s' % (
+ mediaId, toUserName, fileDir))
+ if fileDir or file_:
+ if hasattr(fileDir, 'read'):
+ file_, fileDir = fileDir, None
+ if fileDir is None:
+ fileDir = 'tmp.jpg' # specific fileDir to send gifs
+ else:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Either fileDir or file_ should be specific',
+ 'Ret': -1005, }})
+ if toUserName is None:
+ toUserName = self.storageClass.userName
+ if mediaId is None:
+ r = self.upload_file(fileDir, isPicture=not fileDir[-4:] == '.gif', file_=file_)
+ if r:
+ mediaId = r['MediaId']
+ else:
+ return r
+ url = '%s/webwxsendmsgimg?fun=async&f=json' % self.loginInfo['url']
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Msg': {
+ 'Type': 3,
+ 'MediaId': mediaId,
+ 'FromUserName': self.storageClass.userName,
+ 'ToUserName': toUserName,
+ 'LocalID': int(time.time() * 1e4),
+ 'ClientMsgId': int(time.time() * 1e4), },
+ 'Scene': 0, }
+ if fileDir[-4:] == '.gif':
+ url = '%s/webwxsendemoticon?fun=sys' % self.loginInfo['url']
+ data['Msg']['Type'] = 47
+ data['Msg']['EmojiFlag'] = 2
+ headers = {
+ 'User-Agent': config.USER_AGENT,
+ 'Content-Type': 'application/json;charset=UTF-8', }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
+
+def send_video(self, fileDir=None, toUserName=None, mediaId=None, file_=None):
+ logger.debug('Request to send a video(mediaId: %s) to %s: %s' % (
+ mediaId, toUserName, fileDir))
+ if fileDir or file_:
+ if hasattr(fileDir, 'read'):
+ file_, fileDir = fileDir, None
+ if fileDir is None:
+ fileDir = 'tmp.mp4' # specific fileDir to send other formats
+ else:
+ return ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'Either fileDir or file_ should be specific',
+ 'Ret': -1005, }})
+ if toUserName is None:
+ toUserName = self.storageClass.userName
+ if mediaId is None:
+ r = self.upload_file(fileDir, isVideo=True, file_=file_)
+ if r:
+ mediaId = r['MediaId']
+ else:
+ return r
+ url = '%s/webwxsendvideomsg?fun=async&f=json&pass_ticket=%s' % (
+ self.loginInfo['url'], self.loginInfo['pass_ticket'])
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ 'Msg': {
+ 'Type' : 43,
+ 'MediaId' : mediaId,
+ 'FromUserName' : self.storageClass.userName,
+ 'ToUserName' : toUserName,
+ 'LocalID' : int(time.time() * 1e4),
+ 'ClientMsgId' : int(time.time() * 1e4), },
+ 'Scene': 0, }
+ headers = {
+ 'User-Agent' : config.USER_AGENT,
+ 'Content-Type': 'application/json;charset=UTF-8', }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
+
+def send(self, msg, toUserName=None, mediaId=None):
+ if not msg:
+ r = ReturnValue({'BaseResponse': {
+ 'ErrMsg': 'No message.',
+ 'Ret': -1005, }})
+ elif msg[:5] == '@fil@':
+ if mediaId is None:
+ r = self.send_file(msg[5:], toUserName)
+ else:
+ r = self.send_file(msg[5:], toUserName, mediaId)
+ elif msg[:5] == '@img@':
+ if mediaId is None:
+ r = self.send_image(msg[5:], toUserName)
+ else:
+ r = self.send_image(msg[5:], toUserName, mediaId)
+ elif msg[:5] == '@msg@':
+ r = self.send_msg(msg[5:], toUserName)
+ elif msg[:5] == '@vid@':
+ if mediaId is None:
+ r = self.send_video(msg[5:], toUserName)
+ else:
+ r = self.send_video(msg[5:], toUserName, mediaId)
+ else:
+ r = self.send_msg(msg, toUserName)
+ return r
+
+def revoke(self, msgId, toUserName, localId=None):
+ url = '%s/webwxrevokemsg' % self.loginInfo['url']
+ data = {
+ 'BaseRequest': self.loginInfo['BaseRequest'],
+ "ClientMsgId": localId or str(time.time() * 1e3),
+ "SvrMsgId": msgId,
+ "ToUserName": toUserName}
+ headers = {
+ 'ContentType': 'application/json; charset=UTF-8',
+ 'User-Agent' : config.USER_AGENT }
+ r = self.s.post(url, headers=headers,
+ data=json.dumps(data, ensure_ascii=False).encode('utf8'))
+ return ReturnValue(rawResponse=r)
diff --git a/lib/itchat/components/register.py b/lib/itchat/components/register.py
new file mode 100644
index 0000000..e76f2c4
--- /dev/null
+++ b/lib/itchat/components/register.py
@@ -0,0 +1,106 @@
+import logging, traceback, sys, threading
+try:
+ import Queue
+except ImportError:
+ import queue as Queue
+
+from ..log import set_logging
+from ..utils import test_connect
+from ..storage import templates
+
+logger = logging.getLogger('itchat')
+
+def load_register(core):
+ core.auto_login = auto_login
+ core.configured_reply = configured_reply
+ core.msg_register = msg_register
+ core.run = run
+
+def auto_login(self, hotReload=False, statusStorageDir='itchat.pkl',
+ enableCmdQR=False, picDir=None, qrCallback=None,
+ loginCallback=None, exitCallback=None):
+ if not test_connect():
+ logger.info("You can't get access to internet or wechat domain, so exit.")
+ sys.exit()
+ self.useHotReload = hotReload
+ self.hotReloadDir = statusStorageDir
+ if hotReload:
+ rval=self.load_login_status(statusStorageDir,
+ loginCallback=loginCallback, exitCallback=exitCallback)
+ if rval:
+ return
+ logger.error('Hot reload failed, logging in normally, error={}'.format(rval))
+ self.logout()
+ self.login(enableCmdQR=enableCmdQR, picDir=picDir, qrCallback=qrCallback,
+ loginCallback=loginCallback, exitCallback=exitCallback)
+ self.dump_login_status(statusStorageDir)
+ else:
+ self.login(enableCmdQR=enableCmdQR, picDir=picDir, qrCallback=qrCallback,
+ loginCallback=loginCallback, exitCallback=exitCallback)
+
+def configured_reply(self):
+ ''' determine the type of message and reply if its method is defined
+ however, I use a strange way to determine whether a msg is from massive platform
+ I haven't found a better solution here
+ The main problem I'm worrying about is the mismatching of new friends added on phone
+ If you have any good idea, pleeeease report an issue. I will be more than grateful.
+ '''
+ try:
+ msg = self.msgList.get(timeout=1)
+ except Queue.Empty:
+ pass
+ else:
+ if isinstance(msg['User'], templates.User):
+ replyFn = self.functionDict['FriendChat'].get(msg['Type'])
+ elif isinstance(msg['User'], templates.MassivePlatform):
+ replyFn = self.functionDict['MpChat'].get(msg['Type'])
+ elif isinstance(msg['User'], templates.Chatroom):
+ replyFn = self.functionDict['GroupChat'].get(msg['Type'])
+ if replyFn is None:
+ r = None
+ else:
+ try:
+ r = replyFn(msg)
+ if r is not None:
+ self.send(r, msg.get('FromUserName'))
+ except:
+ logger.warning(traceback.format_exc())
+
+def msg_register(self, msgType, isFriendChat=False, isGroupChat=False, isMpChat=False):
+ ''' a decorator constructor
+ return a specific decorator based on information given '''
+ if not (isinstance(msgType, list) or isinstance(msgType, tuple)):
+ msgType = [msgType]
+ def _msg_register(fn):
+ for _msgType in msgType:
+ if isFriendChat:
+ self.functionDict['FriendChat'][_msgType] = fn
+ if isGroupChat:
+ self.functionDict['GroupChat'][_msgType] = fn
+ if isMpChat:
+ self.functionDict['MpChat'][_msgType] = fn
+ if not any((isFriendChat, isGroupChat, isMpChat)):
+ self.functionDict['FriendChat'][_msgType] = fn
+ return fn
+ return _msg_register
+
+def run(self, debug=False, blockThread=True):
+ logger.info('Start auto replying.')
+ if debug:
+ set_logging(loggingLevel=logging.DEBUG)
+ def reply_fn():
+ try:
+ while self.alive:
+ self.configured_reply()
+ except KeyboardInterrupt:
+ if self.useHotReload:
+ self.dump_login_status()
+ self.alive = False
+ logger.debug('itchat received an ^C and exit.')
+ logger.info('Bye~')
+ if blockThread:
+ reply_fn()
+ else:
+ replyThread = threading.Thread(target=reply_fn)
+ replyThread.setDaemon(True)
+ replyThread.start()
diff --git a/lib/itchat/config.py b/lib/itchat/config.py
new file mode 100644
index 0000000..2ac6328
--- /dev/null
+++ b/lib/itchat/config.py
@@ -0,0 +1,17 @@
+import os, platform
+
+VERSION = '1.5.0.dev'
+
+# use this envrionment to initialize the async & sync componment
+ASYNC_COMPONENTS = os.environ.get('ITCHAT_UOS_ASYNC', False)
+
+BASE_URL = 'https://login.weixin.qq.com'
+OS = platform.system() # Windows, Linux, Darwin
+DIR = os.getcwd()
+DEFAULT_QR = 'QR.png'
+TIMEOUT = (10, 60)
+
+USER_AGENT = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.71 Safari/537.36'
+
+UOS_PATCH_CLIENT_VERSION = '2.0.0'
+UOS_PATCH_EXTSPAM = 'Go8FCIkFEokFCggwMDAwMDAwMRAGGvAESySibk50w5Wb3uTl2c2h64jVVrV7gNs06GFlWplHQbY/5FfiO++1yH4ykCyNPWKXmco+wfQzK5R98D3so7rJ5LmGFvBLjGceleySrc3SOf2Pc1gVehzJgODeS0lDL3/I/0S2SSE98YgKleq6Uqx6ndTy9yaL9qFxJL7eiA/R3SEfTaW1SBoSITIu+EEkXff+Pv8NHOk7N57rcGk1w0ZzRrQDkXTOXFN2iHYIzAAZPIOY45Lsh+A4slpgnDiaOvRtlQYCt97nmPLuTipOJ8Qc5pM7ZsOsAPPrCQL7nK0I7aPrFDF0q4ziUUKettzW8MrAaiVfmbD1/VkmLNVqqZVvBCtRblXb5FHmtS8FxnqCzYP4WFvz3T0TcrOqwLX1M/DQvcHaGGw0B0y4bZMs7lVScGBFxMj3vbFi2SRKbKhaitxHfYHAOAa0X7/MSS0RNAjdwoyGHeOepXOKY+h3iHeqCvgOH6LOifdHf/1aaZNwSkGotYnYScW8Yx63LnSwba7+hESrtPa/huRmB9KWvMCKbDThL/nne14hnL277EDCSocPu3rOSYjuB9gKSOdVmWsj9Dxb/iZIe+S6AiG29Esm+/eUacSba0k8wn5HhHg9d4tIcixrxveflc8vi2/wNQGVFNsGO6tB5WF0xf/plngOvQ1/ivGV/C1Qpdhzznh0ExAVJ6dwzNg7qIEBaw+BzTJTUuRcPk92Sn6QDn2Pu3mpONaEumacjW4w6ipPnPw+g2TfywJjeEcpSZaP4Q3YV5HG8D6UjWA4GSkBKculWpdCMadx0usMomsSS/74QgpYqcPkmamB4nVv1JxczYITIqItIKjD35IGKAUwAA=='
diff --git a/lib/itchat/content.py b/lib/itchat/content.py
new file mode 100644
index 0000000..41dc0b1
--- /dev/null
+++ b/lib/itchat/content.py
@@ -0,0 +1,14 @@
+TEXT = 'Text'
+MAP = 'Map'
+CARD = 'Card'
+NOTE = 'Note'
+SHARING = 'Sharing'
+PICTURE = 'Picture'
+RECORDING = VOICE = 'Recording'
+ATTACHMENT = 'Attachment'
+VIDEO = 'Video'
+FRIENDS = 'Friends'
+SYSTEM = 'System'
+
+INCOME_MSG = [TEXT, MAP, CARD, NOTE, SHARING, PICTURE,
+ RECORDING, VOICE, ATTACHMENT, VIDEO, FRIENDS, SYSTEM]
diff --git a/lib/itchat/core.py b/lib/itchat/core.py
new file mode 100644
index 0000000..f3871b5
--- /dev/null
+++ b/lib/itchat/core.py
@@ -0,0 +1,456 @@
+import requests
+
+from . import storage
+
+class Core(object):
+ def __init__(self):
+ ''' init is the only method defined in core.py
+ alive is value showing whether core is running
+ - you should call logout method to change it
+ - after logout, a core object can login again
+ storageClass only uses basic python types
+ - so for advanced uses, inherit it yourself
+ receivingRetryCount is for receiving loop retry
+ - it's 5 now, but actually even 1 is enough
+ - failing is failing
+ '''
+ self.alive, self.isLogging = False, False
+ self.storageClass = storage.Storage(self)
+ self.memberList = self.storageClass.memberList
+ self.mpList = self.storageClass.mpList
+ self.chatroomList = self.storageClass.chatroomList
+ self.msgList = self.storageClass.msgList
+ self.loginInfo = {}
+ self.s = requests.Session()
+ self.uuid = None
+ self.functionDict = {'FriendChat': {}, 'GroupChat': {}, 'MpChat': {}}
+ self.useHotReload, self.hotReloadDir = False, 'itchat.pkl'
+ self.receivingRetryCount = 5
+ def login(self, enableCmdQR=False, picDir=None, qrCallback=None,
+ loginCallback=None, exitCallback=None):
+ ''' log in like web wechat does
+ for log in
+ - a QR code will be downloaded and opened
+ - then scanning status is logged, it paused for you confirm
+ - finally it logged in and show your nickName
+ for options
+ - enableCmdQR: show qrcode in command line
+ - integers can be used to fit strange char length
+ - picDir: place for storing qrcode
+ - qrCallback: method that should accept uuid, status, qrcode
+ - loginCallback: callback after successfully logged in
+ - if not set, screen is cleared and qrcode is deleted
+ - exitCallback: callback after logged out
+ - it contains calling of logout
+ for usage
+ ..code::python
+
+ import itchat
+ itchat.login()
+
+ it is defined in components/login.py
+ and of course every single move in login can be called outside
+ - you may scan source code to see how
+ - and modified according to your own demand
+ '''
+ raise NotImplementedError()
+ def get_QRuuid(self):
+ ''' get uuid for qrcode
+ uuid is the symbol of qrcode
+ - for logging in, you need to get a uuid first
+ - for downloading qrcode, you need to pass uuid to it
+ - for checking login status, uuid is also required
+ if uuid has timed out, just get another
+ it is defined in components/login.py
+ '''
+ raise NotImplementedError()
+ def get_QR(self, uuid=None, enableCmdQR=False, picDir=None, qrCallback=None):
+ ''' download and show qrcode
+ for options
+ - uuid: if uuid is not set, latest uuid you fetched will be used
+ - enableCmdQR: show qrcode in cmd
+ - picDir: where to store qrcode
+ - qrCallback: method that should accept uuid, status, qrcode
+ it is defined in components/login.py
+ '''
+ raise NotImplementedError()
+ def check_login(self, uuid=None):
+ ''' check login status
+ for options:
+ - uuid: if uuid is not set, latest uuid you fetched will be used
+ for return values:
+ - a string will be returned
+ - for meaning of return values
+ - 200: log in successfully
+ - 201: waiting for press confirm
+ - 408: uuid timed out
+ - 0 : unknown error
+ for processing:
+ - syncUrl and fileUrl is set
+ - BaseRequest is set
+ blocks until reaches any of above status
+ it is defined in components/login.py
+ '''
+ raise NotImplementedError()
+ def web_init(self):
+ ''' get info necessary for initializing
+ for processing:
+ - own account info is set
+ - inviteStartCount is set
+ - syncKey is set
+ - part of contact is fetched
+ it is defined in components/login.py
+ '''
+ raise NotImplementedError()
+ def show_mobile_login(self):
+ ''' show web wechat login sign
+ the sign is on the top of mobile phone wechat
+ sign will be added after sometime even without calling this function
+ it is defined in components/login.py
+ '''
+ raise NotImplementedError()
+ def start_receiving(self, exitCallback=None, getReceivingFnOnly=False):
+ ''' open a thread for heart loop and receiving messages
+ for options:
+ - exitCallback: callback after logged out
+ - it contains calling of logout
+ - getReceivingFnOnly: if True thread will not be created and started. Instead, receive fn will be returned.
+ for processing:
+ - messages: msgs are formatted and passed on to registered fns
+ - contact : chatrooms are updated when related info is received
+ it is defined in components/login.py
+ '''
+ raise NotImplementedError()
+ def get_msg(self):
+ ''' fetch messages
+ for fetching
+ - method blocks for sometime until
+ - new messages are to be received
+ - or anytime they like
+ - synckey is updated with returned synccheckkey
+ it is defined in components/login.py
+ '''
+ raise NotImplementedError()
+ def logout(self):
+ ''' logout
+ if core is now alive
+ logout will tell wechat backstage to logout
+ and core gets ready for another login
+ it is defined in components/login.py
+ '''
+ raise NotImplementedError()
+ def update_chatroom(self, userName, detailedMember=False):
+ ''' update chatroom
+ for chatroom contact
+ - a chatroom contact need updating to be detailed
+ - detailed means members, encryid, etc
+ - auto updating of heart loop is a more detailed updating
+ - member uin will also be filled
+ - once called, updated info will be stored
+ for options
+ - userName: 'UserName' key of chatroom or a list of it
+ - detailedMember: whether to get members of contact
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def update_friend(self, userName):
+ ''' update chatroom
+ for friend contact
+ - once called, updated info will be stored
+ for options
+ - userName: 'UserName' key of a friend or a list of it
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def get_contact(self, update=False):
+ ''' fetch part of contact
+ for part
+ - all the massive platforms and friends are fetched
+ - if update, only starred chatrooms are fetched
+ for options
+ - update: if not set, local value will be returned
+ for results
+ - chatroomList will be returned
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def get_friends(self, update=False):
+ ''' fetch friends list
+ for options
+ - update: if not set, local value will be returned
+ for results
+ - a list of friends' info dicts will be returned
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def get_chatrooms(self, update=False, contactOnly=False):
+ ''' fetch chatrooms list
+ for options
+ - update: if not set, local value will be returned
+ - contactOnly: if set, only starred chatrooms will be returned
+ for results
+ - a list of chatrooms' info dicts will be returned
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def get_mps(self, update=False):
+ ''' fetch massive platforms list
+ for options
+ - update: if not set, local value will be returned
+ for results
+ - a list of platforms' info dicts will be returned
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def set_alias(self, userName, alias):
+ ''' set alias for a friend
+ for options
+ - userName: 'UserName' key of info dict
+ - alias: new alias
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def set_pinned(self, userName, isPinned=True):
+ ''' set pinned for a friend or a chatroom
+ for options
+ - userName: 'UserName' key of info dict
+ - isPinned: whether to pin
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def accept_friend(self, userName, v4,autoUpdate=True):
+ ''' accept a friend or accept a friend
+ for options
+ - userName: 'UserName' for friend's info dict
+ - status:
+ - for adding status should be 2
+ - for accepting status should be 3
+ - ticket: greeting message
+ - userInfo: friend's other info for adding into local storage
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def get_head_img(self, userName=None, chatroomUserName=None, picDir=None):
+ ''' place for docs
+ for options
+ - if you want to get chatroom header: only set chatroomUserName
+ - if you want to get friend header: only set userName
+ - if you want to get chatroom member header: set both
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def create_chatroom(self, memberList, topic=''):
+ ''' create a chatroom
+ for creating
+ - its calling frequency is strictly limited
+ for options
+ - memberList: list of member info dict
+ - topic: topic of new chatroom
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def set_chatroom_name(self, chatroomUserName, name):
+ ''' set chatroom name
+ for setting
+ - it makes an updating of chatroom
+ - which means detailed info will be returned in heart loop
+ for options
+ - chatroomUserName: 'UserName' key of chatroom info dict
+ - name: new chatroom name
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def delete_member_from_chatroom(self, chatroomUserName, memberList):
+ ''' deletes members from chatroom
+ for deleting
+ - you can't delete yourself
+ - if so, no one will be deleted
+ - strict-limited frequency
+ for options
+ - chatroomUserName: 'UserName' key of chatroom info dict
+ - memberList: list of members' info dict
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def add_member_into_chatroom(self, chatroomUserName, memberList,
+ useInvitation=False):
+ ''' add members into chatroom
+ for adding
+ - you can't add yourself or member already in chatroom
+ - if so, no one will be added
+ - if member will over 40 after adding, invitation must be used
+ - strict-limited frequency
+ for options
+ - chatroomUserName: 'UserName' key of chatroom info dict
+ - memberList: list of members' info dict
+ - useInvitation: if invitation is not required, set this to use
+ it is defined in components/contact.py
+ '''
+ raise NotImplementedError()
+ def send_raw_msg(self, msgType, content, toUserName):
+ ''' many messages are sent in a common way
+ for demo
+ .. code:: python
+
+ @itchat.msg_register(itchat.content.CARD)
+ def reply(msg):
+ itchat.send_raw_msg(msg['MsgType'], msg['Content'], msg['FromUserName'])
+
+ there are some little tricks here, you may discover them yourself
+ but remember they are tricks
+ it is defined in components/messages.py
+ '''
+ raise NotImplementedError()
+ def send_msg(self, msg='Test Message', toUserName=None):
+ ''' send plain text message
+ for options
+ - msg: should be unicode if there's non-ascii words in msg
+ - toUserName: 'UserName' key of friend dict
+ it is defined in components/messages.py
+ '''
+ raise NotImplementedError()
+ def upload_file(self, fileDir, isPicture=False, isVideo=False,
+ toUserName='filehelper', file_=None, preparedFile=None):
+ ''' upload file to server and get mediaId
+ for options
+ - fileDir: dir for file ready for upload
+ - isPicture: whether file is a picture
+ - isVideo: whether file is a video
+ for return values
+ will return a ReturnValue
+ if succeeded, mediaId is in r['MediaId']
+ it is defined in components/messages.py
+ '''
+ raise NotImplementedError()
+ def send_file(self, fileDir, toUserName=None, mediaId=None, file_=None):
+ ''' send attachment
+ for options
+ - fileDir: dir for file ready for upload
+ - mediaId: mediaId for file.
+ - if set, file will not be uploaded twice
+ - toUserName: 'UserName' key of friend dict
+ it is defined in components/messages.py
+ '''
+ raise NotImplementedError()
+ def send_image(self, fileDir=None, toUserName=None, mediaId=None, file_=None):
+ ''' send image
+ for options
+ - fileDir: dir for file ready for upload
+ - if it's a gif, name it like 'xx.gif'
+ - mediaId: mediaId for file.
+ - if set, file will not be uploaded twice
+ - toUserName: 'UserName' key of friend dict
+ it is defined in components/messages.py
+ '''
+ raise NotImplementedError()
+ def send_video(self, fileDir=None, toUserName=None, mediaId=None, file_=None):
+ ''' send video
+ for options
+ - fileDir: dir for file ready for upload
+ - if mediaId is set, it's unnecessary to set fileDir
+ - mediaId: mediaId for file.
+ - if set, file will not be uploaded twice
+ - toUserName: 'UserName' key of friend dict
+ it is defined in components/messages.py
+ '''
+ raise NotImplementedError()
+ def send(self, msg, toUserName=None, mediaId=None):
+ ''' wrapped function for all the sending functions
+ for options
+ - msg: message starts with different string indicates different type
+ - list of type string: ['@fil@', '@img@', '@msg@', '@vid@']
+ - they are for file, image, plain text, video
+ - if none of them matches, it will be sent like plain text
+ - toUserName: 'UserName' key of friend dict
+ - mediaId: if set, uploading will not be repeated
+ it is defined in components/messages.py
+ '''
+ raise NotImplementedError()
+ def revoke(self, msgId, toUserName, localId=None):
+ ''' revoke message with its and msgId
+ for options
+ - msgId: message Id on server
+ - toUserName: 'UserName' key of friend dict
+ - localId: message Id at local (optional)
+ it is defined in components/messages.py
+ '''
+ raise NotImplementedError()
+ def dump_login_status(self, fileDir=None):
+ ''' dump login status to a specific file
+ for option
+ - fileDir: dir for dumping login status
+ it is defined in components/hotreload.py
+ '''
+ raise NotImplementedError()
+ def load_login_status(self, fileDir,
+ loginCallback=None, exitCallback=None):
+ ''' load login status from a specific file
+ for option
+ - fileDir: file for loading login status
+ - loginCallback: callback after successfully logged in
+ - if not set, screen is cleared and qrcode is deleted
+ - exitCallback: callback after logged out
+ - it contains calling of logout
+ it is defined in components/hotreload.py
+ '''
+ raise NotImplementedError()
+ def auto_login(self, hotReload=False, statusStorageDir='itchat.pkl',
+ enableCmdQR=False, picDir=None, qrCallback=None,
+ loginCallback=None, exitCallback=None):
+ ''' log in like web wechat does
+ for log in
+ - a QR code will be downloaded and opened
+ - then scanning status is logged, it paused for you confirm
+ - finally it logged in and show your nickName
+ for options
+ - hotReload: enable hot reload
+ - statusStorageDir: dir for storing log in status
+ - enableCmdQR: show qrcode in command line
+ - integers can be used to fit strange char length
+ - picDir: place for storing qrcode
+ - loginCallback: callback after successfully logged in
+ - if not set, screen is cleared and qrcode is deleted
+ - exitCallback: callback after logged out
+ - it contains calling of logout
+ - qrCallback: method that should accept uuid, status, qrcode
+ for usage
+ ..code::python
+
+ import itchat
+ itchat.auto_login()
+
+ it is defined in components/register.py
+ and of course every single move in login can be called outside
+ - you may scan source code to see how
+ - and modified according to your own demond
+ '''
+ raise NotImplementedError()
+ def configured_reply(self):
+ ''' determine the type of message and reply if its method is defined
+ however, I use a strange way to determine whether a msg is from massive platform
+ I haven't found a better solution here
+ The main problem I'm worrying about is the mismatching of new friends added on phone
+ If you have any good idea, pleeeease report an issue. I will be more than grateful.
+ '''
+ raise NotImplementedError()
+ def msg_register(self, msgType,
+ isFriendChat=False, isGroupChat=False, isMpChat=False):
+ ''' a decorator constructor
+ return a specific decorator based on information given
+ '''
+ raise NotImplementedError()
+ def run(self, debug=True, blockThread=True):
+ ''' start auto respond
+ for option
+ - debug: if set, debug info will be shown on screen
+ it is defined in components/register.py
+ '''
+ raise NotImplementedError()
+ def search_friends(self, name=None, userName=None, remarkName=None, nickName=None,
+ wechatAccount=None):
+ return self.storageClass.search_friends(name, userName, remarkName,
+ nickName, wechatAccount)
+ def search_chatrooms(self, name=None, userName=None):
+ return self.storageClass.search_chatrooms(name, userName)
+ def search_mps(self, name=None, userName=None):
+ return self.storageClass.search_mps(name, userName)
diff --git a/lib/itchat/log.py b/lib/itchat/log.py
new file mode 100644
index 0000000..4485cc9
--- /dev/null
+++ b/lib/itchat/log.py
@@ -0,0 +1,36 @@
+import logging
+
+class LogSystem(object):
+ handlerList = []
+ showOnCmd = True
+ loggingLevel = logging.INFO
+ loggingFile = None
+ def __init__(self):
+ self.logger = logging.getLogger('itchat')
+ self.logger.addHandler(logging.NullHandler())
+ self.logger.setLevel(self.loggingLevel)
+ self.cmdHandler = logging.StreamHandler()
+ self.fileHandler = None
+ self.logger.addHandler(self.cmdHandler)
+ def set_logging(self, showOnCmd=True, loggingFile=None,
+ loggingLevel=logging.INFO):
+ if showOnCmd != self.showOnCmd:
+ if showOnCmd:
+ self.logger.addHandler(self.cmdHandler)
+ else:
+ self.logger.removeHandler(self.cmdHandler)
+ self.showOnCmd = showOnCmd
+ if loggingFile != self.loggingFile:
+ if self.loggingFile is not None: # clear old fileHandler
+ self.logger.removeHandler(self.fileHandler)
+ self.fileHandler.close()
+ if loggingFile is not None: # add new fileHandler
+ self.fileHandler = logging.FileHandler(loggingFile)
+ self.logger.addHandler(self.fileHandler)
+ self.loggingFile = loggingFile
+ if loggingLevel != self.loggingLevel:
+ self.logger.setLevel(loggingLevel)
+ self.loggingLevel = loggingLevel
+
+ls = LogSystem()
+set_logging = ls.set_logging
diff --git a/lib/itchat/returnvalues.py b/lib/itchat/returnvalues.py
new file mode 100644
index 0000000..f42f4e8
--- /dev/null
+++ b/lib/itchat/returnvalues.py
@@ -0,0 +1,67 @@
+#coding=utf8
+TRANSLATE = 'Chinese'
+
+class ReturnValue(dict):
+ ''' turn return value of itchat into a boolean value
+ for requests:
+ ..code::python
+
+ import requests
+ r = requests.get('http://httpbin.org/get')
+ print(ReturnValue(rawResponse=r)
+
+ for normal dict:
+ ..code::python
+
+ returnDict = {
+ 'BaseResponse': {
+ 'Ret': 0,
+ 'ErrMsg': 'My error msg', }, }
+ print(ReturnValue(returnDict))
+ '''
+ def __init__(self, returnValueDict={}, rawResponse=None):
+ if rawResponse:
+ try:
+ returnValueDict = rawResponse.json()
+ except ValueError:
+ returnValueDict = {
+ 'BaseResponse': {
+ 'Ret': -1004,
+ 'ErrMsg': 'Unexpected return value', },
+ 'Data': rawResponse.content, }
+ for k, v in returnValueDict.items():
+ self[k] = v
+ if not 'BaseResponse' in self:
+ self['BaseResponse'] = {
+ 'ErrMsg': 'no BaseResponse in raw response',
+ 'Ret': -1000, }
+ if TRANSLATE:
+ self['BaseResponse']['RawMsg'] = self['BaseResponse'].get('ErrMsg', '')
+ self['BaseResponse']['ErrMsg'] = \
+ TRANSLATION[TRANSLATE].get(
+ self['BaseResponse'].get('Ret', '')) \
+ or self['BaseResponse'].get('ErrMsg', u'No ErrMsg')
+ self['BaseResponse']['RawMsg'] = \
+ self['BaseResponse']['RawMsg'] or self['BaseResponse']['ErrMsg']
+ def __nonzero__(self):
+ return self['BaseResponse'].get('Ret') == 0
+ def __bool__(self):
+ return self.__nonzero__()
+ def __str__(self):
+ return '{%s}' % ', '.join(
+ ['%s: %s' % (repr(k),repr(v)) for k,v in self.items()])
+ def __repr__(self):
+ return '' % self.__str__()
+
+TRANSLATION = {
+ 'Chinese': {
+ -1000: u'返回值不带BaseResponse',
+ -1001: u'无法找到对应的成员',
+ -1002: u'文件位置错误',
+ -1003: u'服务器拒绝连接',
+ -1004: u'服务器返回异常值',
+ -1005: u'参数错误',
+ -1006: u'无效操作',
+ 0: u'请求成功',
+ },
+}
diff --git a/lib/itchat/storage/__init__.py b/lib/itchat/storage/__init__.py
new file mode 100644
index 0000000..5c65724
--- /dev/null
+++ b/lib/itchat/storage/__init__.py
@@ -0,0 +1,117 @@
+import os, time, copy
+from threading import Lock
+
+from .messagequeue import Queue
+from .templates import (
+ ContactList, AbstractUserDict, User,
+ MassivePlatform, Chatroom, ChatroomMember)
+
+def contact_change(fn):
+ def _contact_change(core, *args, **kwargs):
+ with core.storageClass.updateLock:
+ return fn(core, *args, **kwargs)
+ return _contact_change
+
+class Storage(object):
+ def __init__(self, core):
+ self.userName = None
+ self.nickName = None
+ self.updateLock = Lock()
+ self.memberList = ContactList()
+ self.mpList = ContactList()
+ self.chatroomList = ContactList()
+ self.msgList = Queue(-1)
+ self.lastInputUserName = None
+ self.memberList.set_default_value(contactClass=User)
+ self.memberList.core = core
+ self.mpList.set_default_value(contactClass=MassivePlatform)
+ self.mpList.core = core
+ self.chatroomList.set_default_value(contactClass=Chatroom)
+ self.chatroomList.core = core
+ def dumps(self):
+ return {
+ 'userName' : self.userName,
+ 'nickName' : self.nickName,
+ 'memberList' : self.memberList,
+ 'mpList' : self.mpList,
+ 'chatroomList' : self.chatroomList,
+ 'lastInputUserName' : self.lastInputUserName, }
+ def loads(self, j):
+ self.userName = j.get('userName', None)
+ self.nickName = j.get('nickName', None)
+ del self.memberList[:]
+ for i in j.get('memberList', []):
+ self.memberList.append(i)
+ del self.mpList[:]
+ for i in j.get('mpList', []):
+ self.mpList.append(i)
+ del self.chatroomList[:]
+ for i in j.get('chatroomList', []):
+ self.chatroomList.append(i)
+ # I tried to solve everything in pickle
+ # but this way is easier and more storage-saving
+ for chatroom in self.chatroomList:
+ if 'MemberList' in chatroom:
+ for member in chatroom['MemberList']:
+ member.core = chatroom.core
+ member.chatroom = chatroom
+ if 'Self' in chatroom:
+ chatroom['Self'].core = chatroom.core
+ chatroom['Self'].chatroom = chatroom
+ self.lastInputUserName = j.get('lastInputUserName', None)
+ def search_friends(self, name=None, userName=None, remarkName=None, nickName=None,
+ wechatAccount=None):
+ with self.updateLock:
+ if (name or userName or remarkName or nickName or wechatAccount) is None:
+ return copy.deepcopy(self.memberList[0]) # my own account
+ elif userName: # return the only userName match
+ for m in self.memberList:
+ if m['UserName'] == userName:
+ return copy.deepcopy(m)
+ else:
+ matchDict = {
+ 'RemarkName' : remarkName,
+ 'NickName' : nickName,
+ 'Alias' : wechatAccount, }
+ for k in ('RemarkName', 'NickName', 'Alias'):
+ if matchDict[k] is None:
+ del matchDict[k]
+ if name: # select based on name
+ contact = []
+ for m in self.memberList:
+ if any([m.get(k) == name for k in ('RemarkName', 'NickName', 'Alias')]):
+ contact.append(m)
+ else:
+ contact = self.memberList[:]
+ if matchDict: # select again based on matchDict
+ friendList = []
+ for m in contact:
+ if all([m.get(k) == v for k, v in matchDict.items()]):
+ friendList.append(m)
+ return copy.deepcopy(friendList)
+ else:
+ return copy.deepcopy(contact)
+ def search_chatrooms(self, name=None, userName=None):
+ with self.updateLock:
+ if userName is not None:
+ for m in self.chatroomList:
+ if m['UserName'] == userName:
+ return copy.deepcopy(m)
+ elif name is not None:
+ matchList = []
+ for m in self.chatroomList:
+ if name in m['NickName']:
+ matchList.append(copy.deepcopy(m))
+ return matchList
+ def search_mps(self, name=None, userName=None):
+ with self.updateLock:
+ if userName is not None:
+ for m in self.mpList:
+ if m['UserName'] == userName:
+ return copy.deepcopy(m)
+ elif name is not None:
+ matchList = []
+ for m in self.mpList:
+ if name in m['NickName']:
+ matchList.append(copy.deepcopy(m))
+ return matchList
diff --git a/lib/itchat/storage/messagequeue.py b/lib/itchat/storage/messagequeue.py
new file mode 100644
index 0000000..53ed669
--- /dev/null
+++ b/lib/itchat/storage/messagequeue.py
@@ -0,0 +1,32 @@
+import logging
+try:
+ import Queue as queue
+except ImportError:
+ import queue
+
+from .templates import AttributeDict
+
+logger = logging.getLogger('itchat')
+
+class Queue(queue.Queue):
+ def put(self, message):
+ queue.Queue.put(self, Message(message))
+
+class Message(AttributeDict):
+ def download(self, fileName):
+ if hasattr(self.text, '__call__'):
+ return self.text(fileName)
+ else:
+ return b''
+ def __getitem__(self, value):
+ if value in ('isAdmin', 'isAt'):
+ v = value[0].upper() + value[1:] # ''[1:] == ''
+ logger.debug('%s is expired in 1.3.0, use %s instead.' % (value, v))
+ value = v
+ return super(Message, self).__getitem__(value)
+ def __str__(self):
+ return '{%s}' % ', '.join(
+ ['%s: %s' % (repr(k),repr(v)) for k,v in self.items()])
+ def __repr__(self):
+ return '<%s: %s>' % (self.__class__.__name__.split('.')[-1],
+ self.__str__())
diff --git a/lib/itchat/storage/templates.py b/lib/itchat/storage/templates.py
new file mode 100644
index 0000000..6a670d7
--- /dev/null
+++ b/lib/itchat/storage/templates.py
@@ -0,0 +1,318 @@
+import logging, copy, pickle
+from weakref import ref
+
+from ..returnvalues import ReturnValue
+from ..utils import update_info_dict
+
+logger = logging.getLogger('itchat')
+
+class AttributeDict(dict):
+ def __getattr__(self, value):
+ keyName = value[0].upper() + value[1:]
+ try:
+ return self[keyName]
+ except KeyError:
+ raise AttributeError("'%s' object has no attribute '%s'" % (
+ self.__class__.__name__.split('.')[-1], keyName))
+ def get(self, v, d=None):
+ try:
+ return self[v]
+ except KeyError:
+ return d
+
+class UnInitializedItchat(object):
+ def _raise_error(self, *args, **kwargs):
+ logger.warning('An itchat instance is called before initialized')
+ def __getattr__(self, value):
+ return self._raise_error
+
+class ContactList(list):
+ ''' when a dict is append, init function will be called to format that dict '''
+ def __init__(self, *args, **kwargs):
+ super(ContactList, self).__init__(*args, **kwargs)
+ self.__setstate__(None)
+ @property
+ def core(self):
+ return getattr(self, '_core', lambda: fakeItchat)() or fakeItchat
+ @core.setter
+ def core(self, value):
+ self._core = ref(value)
+ def set_default_value(self, initFunction=None, contactClass=None):
+ if hasattr(initFunction, '__call__'):
+ self.contactInitFn = initFunction
+ if hasattr(contactClass, '__call__'):
+ self.contactClass = contactClass
+ def append(self, value):
+ contact = self.contactClass(value)
+ contact.core = self.core
+ if self.contactInitFn is not None:
+ contact = self.contactInitFn(self, contact) or contact
+ super(ContactList, self).append(contact)
+ def __deepcopy__(self, memo):
+ r = self.__class__([copy.deepcopy(v) for v in self])
+ r.contactInitFn = self.contactInitFn
+ r.contactClass = self.contactClass
+ r.core = self.core
+ return r
+ def __getstate__(self):
+ return 1
+ def __setstate__(self, state):
+ self.contactInitFn = None
+ self.contactClass = User
+ def __str__(self):
+ return '[%s]' % ', '.join([repr(v) for v in self])
+ def __repr__(self):
+ return '<%s: %s>' % (self.__class__.__name__.split('.')[-1],
+ self.__str__())
+
+class AbstractUserDict(AttributeDict):
+ def __init__(self, *args, **kwargs):
+ super(AbstractUserDict, self).__init__(*args, **kwargs)
+ @property
+ def core(self):
+ return getattr(self, '_core', lambda: fakeItchat)() or fakeItchat
+ @core.setter
+ def core(self, value):
+ self._core = ref(value)
+ def update(self):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not be updated' % \
+ self.__class__.__name__, }, })
+ def set_alias(self, alias):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not set alias' % \
+ self.__class__.__name__, }, })
+ def set_pinned(self, isPinned=True):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not be pinned' % \
+ self.__class__.__name__, }, })
+ def verify(self):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s do not need verify' % \
+ self.__class__.__name__, }, })
+ def get_head_image(self, imageDir=None):
+ return self.core.get_head_img(self.userName, picDir=imageDir)
+ def delete_member(self, userName):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not delete member' % \
+ self.__class__.__name__, }, })
+ def add_member(self, userName):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not add member' % \
+ self.__class__.__name__, }, })
+ def send_raw_msg(self, msgType, content):
+ return self.core.send_raw_msg(msgType, content, self.userName)
+ def send_msg(self, msg='Test Message'):
+ return self.core.send_msg(msg, self.userName)
+ def send_file(self, fileDir, mediaId=None):
+ return self.core.send_file(fileDir, self.userName, mediaId)
+ def send_image(self, fileDir, mediaId=None):
+ return self.core.send_image(fileDir, self.userName, mediaId)
+ def send_video(self, fileDir=None, mediaId=None):
+ return self.core.send_video(fileDir, self.userName, mediaId)
+ def send(self, msg, mediaId=None):
+ return self.core.send(msg, self.userName, mediaId)
+ def search_member(self, name=None, userName=None, remarkName=None, nickName=None,
+ wechatAccount=None):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s do not have members' % \
+ self.__class__.__name__, }, })
+ def __deepcopy__(self, memo):
+ r = self.__class__()
+ for k, v in self.items():
+ r[copy.deepcopy(k)] = copy.deepcopy(v)
+ r.core = self.core
+ return r
+ def __str__(self):
+ return '{%s}' % ', '.join(
+ ['%s: %s' % (repr(k),repr(v)) for k,v in self.items()])
+ def __repr__(self):
+ return '<%s: %s>' % (self.__class__.__name__.split('.')[-1],
+ self.__str__())
+ def __getstate__(self):
+ return 1
+ def __setstate__(self, state):
+ pass
+
+class User(AbstractUserDict):
+ def __init__(self, *args, **kwargs):
+ super(User, self).__init__(*args, **kwargs)
+ self.__setstate__(None)
+ def update(self):
+ r = self.core.update_friend(self.userName)
+ if r:
+ update_info_dict(self, r)
+ return r
+ def set_alias(self, alias):
+ return self.core.set_alias(self.userName, alias)
+ def set_pinned(self, isPinned=True):
+ return self.core.set_pinned(self.userName, isPinned)
+ def verify(self):
+ return self.core.add_friend(**self.verifyDict)
+ def __deepcopy__(self, memo):
+ r = super(User, self).__deepcopy__(memo)
+ r.verifyDict = copy.deepcopy(self.verifyDict)
+ return r
+ def __setstate__(self, state):
+ super(User, self).__setstate__(state)
+ self.verifyDict = {}
+ self['MemberList'] = fakeContactList
+
+class MassivePlatform(AbstractUserDict):
+ def __init__(self, *args, **kwargs):
+ super(MassivePlatform, self).__init__(*args, **kwargs)
+ self.__setstate__(None)
+ def __setstate__(self, state):
+ super(MassivePlatform, self).__setstate__(state)
+ self['MemberList'] = fakeContactList
+
+class Chatroom(AbstractUserDict):
+ def __init__(self, *args, **kwargs):
+ super(Chatroom, self).__init__(*args, **kwargs)
+ memberList = ContactList()
+ userName = self.get('UserName', '')
+ refSelf = ref(self)
+ def init_fn(parentList, d):
+ d.chatroom = refSelf() or \
+ parentList.core.search_chatrooms(userName=userName)
+ memberList.set_default_value(init_fn, ChatroomMember)
+ if 'MemberList' in self:
+ for member in self.memberList:
+ memberList.append(member)
+ self['MemberList'] = memberList
+ @property
+ def core(self):
+ return getattr(self, '_core', lambda: fakeItchat)() or fakeItchat
+ @core.setter
+ def core(self, value):
+ self._core = ref(value)
+ self.memberList.core = value
+ for member in self.memberList:
+ member.core = value
+ def update(self, detailedMember=False):
+ r = self.core.update_chatroom(self.userName, detailedMember)
+ if r:
+ update_info_dict(self, r)
+ self['MemberList'] = r['MemberList']
+ return r
+ def set_alias(self, alias):
+ return self.core.set_chatroom_name(self.userName, alias)
+ def set_pinned(self, isPinned=True):
+ return self.core.set_pinned(self.userName, isPinned)
+ def delete_member(self, userName):
+ return self.core.delete_member_from_chatroom(self.userName, userName)
+ def add_member(self, userName):
+ return self.core.add_member_into_chatroom(self.userName, userName)
+ def search_member(self, name=None, userName=None, remarkName=None, nickName=None,
+ wechatAccount=None):
+ with self.core.storageClass.updateLock:
+ if (name or userName or remarkName or nickName or wechatAccount) is None:
+ return None
+ elif userName: # return the only userName match
+ for m in self.memberList:
+ if m.userName == userName:
+ return copy.deepcopy(m)
+ else:
+ matchDict = {
+ 'RemarkName' : remarkName,
+ 'NickName' : nickName,
+ 'Alias' : wechatAccount, }
+ for k in ('RemarkName', 'NickName', 'Alias'):
+ if matchDict[k] is None:
+ del matchDict[k]
+ if name: # select based on name
+ contact = []
+ for m in self.memberList:
+ if any([m.get(k) == name for k in ('RemarkName', 'NickName', 'Alias')]):
+ contact.append(m)
+ else:
+ contact = self.memberList[:]
+ if matchDict: # select again based on matchDict
+ friendList = []
+ for m in contact:
+ if all([m.get(k) == v for k, v in matchDict.items()]):
+ friendList.append(m)
+ return copy.deepcopy(friendList)
+ else:
+ return copy.deepcopy(contact)
+ def __setstate__(self, state):
+ super(Chatroom, self).__setstate__(state)
+ if not 'MemberList' in self:
+ self['MemberList'] = fakeContactList
+
+class ChatroomMember(AbstractUserDict):
+ def __init__(self, *args, **kwargs):
+ super(AbstractUserDict, self).__init__(*args, **kwargs)
+ self.__setstate__(None)
+ @property
+ def chatroom(self):
+ r = getattr(self, '_chatroom', lambda: fakeChatroom)()
+ if r is None:
+ userName = getattr(self, '_chatroomUserName', '')
+ r = self.core.search_chatrooms(userName=userName)
+ if isinstance(r, dict):
+ self.chatroom = r
+ return r or fakeChatroom
+ @chatroom.setter
+ def chatroom(self, value):
+ if isinstance(value, dict) and 'UserName' in value:
+ self._chatroom = ref(value)
+ self._chatroomUserName = value['UserName']
+ def get_head_image(self, imageDir=None):
+ return self.core.get_head_img(self.userName, self.chatroom.userName, picDir=imageDir)
+ def delete_member(self, userName):
+ return self.core.delete_member_from_chatroom(self.chatroom.userName, self.userName)
+ def send_raw_msg(self, msgType, content):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not send message directly' % \
+ self.__class__.__name__, }, })
+ def send_msg(self, msg='Test Message'):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not send message directly' % \
+ self.__class__.__name__, }, })
+ def send_file(self, fileDir, mediaId=None):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not send message directly' % \
+ self.__class__.__name__, }, })
+ def send_image(self, fileDir, mediaId=None):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not send message directly' % \
+ self.__class__.__name__, }, })
+ def send_video(self, fileDir=None, mediaId=None):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not send message directly' % \
+ self.__class__.__name__, }, })
+ def send(self, msg, mediaId=None):
+ return ReturnValue({'BaseResponse': {
+ 'Ret': -1006,
+ 'ErrMsg': '%s can not send message directly' % \
+ self.__class__.__name__, }, })
+ def __setstate__(self, state):
+ super(ChatroomMember, self).__setstate__(state)
+ self['MemberList'] = fakeContactList
+
+def wrap_user_dict(d):
+ userName = d.get('UserName')
+ if '@@' in userName:
+ r = Chatroom(d)
+ elif d.get('VerifyFlag', 8) & 8 == 0:
+ r = User(d)
+ else:
+ r = MassivePlatform(d)
+ return r
+
+fakeItchat = UnInitializedItchat()
+fakeContactList = ContactList()
+fakeChatroom = Chatroom()
diff --git a/lib/itchat/utils.py b/lib/itchat/utils.py
new file mode 100644
index 0000000..c5dfe24
--- /dev/null
+++ b/lib/itchat/utils.py
@@ -0,0 +1,163 @@
+import re, os, sys, subprocess, copy, traceback, logging
+
+try:
+ from HTMLParser import HTMLParser
+except ImportError:
+ from html.parser import HTMLParser
+try:
+ from urllib import quote as _quote
+ quote = lambda n: _quote(n.encode('utf8', 'replace'))
+except ImportError:
+ from urllib.parse import quote
+
+import requests
+
+from . import config
+
+logger = logging.getLogger('itchat')
+
+emojiRegex = re.compile(r'')
+htmlParser = HTMLParser()
+if not hasattr(htmlParser, 'unescape'):
+ import html
+ htmlParser.unescape = html.unescape
+ # FIX Python 3.9 HTMLParser.unescape is removed. See https://docs.python.org/3.9/whatsnew/3.9.html
+try:
+ b = u'\u2588'
+ sys.stdout.write(b + '\r')
+ sys.stdout.flush()
+except UnicodeEncodeError:
+ BLOCK = 'MM'
+else:
+ BLOCK = b
+friendInfoTemplate = {}
+for k in ('UserName', 'City', 'DisplayName', 'PYQuanPin', 'RemarkPYInitial', 'Province',
+ 'KeyWord', 'RemarkName', 'PYInitial', 'EncryChatRoomId', 'Alias', 'Signature',
+ 'NickName', 'RemarkPYQuanPin', 'HeadImgUrl'):
+ friendInfoTemplate[k] = ''
+for k in ('UniFriend', 'Sex', 'AppAccountFlag', 'VerifyFlag', 'ChatRoomId', 'HideInputBarFlag',
+ 'AttrStatus', 'SnsFlag', 'MemberCount', 'OwnerUin', 'ContactFlag', 'Uin',
+ 'StarFriend', 'Statues'):
+ friendInfoTemplate[k] = 0
+friendInfoTemplate['MemberList'] = []
+
+def clear_screen():
+ os.system('cls' if config.OS == 'Windows' else 'clear')
+
+def emoji_formatter(d, k):
+ ''' _emoji_deebugger is for bugs about emoji match caused by wechat backstage
+ like :face with tears of joy: will be replaced with :cat face with tears of joy:
+ '''
+ def _emoji_debugger(d, k):
+ s = d[k].replace('') # fix missing bug
+ def __fix_miss_match(m):
+ return '' % ({
+ '1f63c': '1f601', '1f639': '1f602', '1f63a': '1f603',
+ '1f4ab': '1f616', '1f64d': '1f614', '1f63b': '1f60d',
+ '1f63d': '1f618', '1f64e': '1f621', '1f63f': '1f622',
+ }.get(m.group(1), m.group(1)))
+ return emojiRegex.sub(__fix_miss_match, s)
+ def _emoji_formatter(m):
+ s = m.group(1)
+ if len(s) == 6:
+ return ('\\U%s\\U%s'%(s[:2].rjust(8, '0'), s[2:].rjust(8, '0'))
+ ).encode('utf8').decode('unicode-escape', 'replace')
+ elif len(s) == 10:
+ return ('\\U%s\\U%s'%(s[:5].rjust(8, '0'), s[5:].rjust(8, '0'))
+ ).encode('utf8').decode('unicode-escape', 'replace')
+ else:
+ return ('\\U%s'%m.group(1).rjust(8, '0')
+ ).encode('utf8').decode('unicode-escape', 'replace')
+ d[k] = _emoji_debugger(d, k)
+ d[k] = emojiRegex.sub(_emoji_formatter, d[k])
+
+def msg_formatter(d, k):
+ emoji_formatter(d, k)
+ d[k] = d[k].replace('
', '\n')
+ d[k] = htmlParser.unescape(d[k])
+
+def check_file(fileDir):
+ try:
+ with open(fileDir):
+ pass
+ return True
+ except:
+ return False
+
+def print_qr(fileDir):
+ if config.OS == 'Darwin':
+ subprocess.call(['open', fileDir])
+ elif config.OS == 'Linux':
+ subprocess.call(['xdg-open', fileDir])
+ else:
+ os.startfile(fileDir)
+
+def print_cmd_qr(qrText, white=BLOCK, black=' ', enableCmdQR=True):
+ blockCount = int(enableCmdQR)
+ if abs(blockCount) == 0:
+ blockCount = 1
+ white *= abs(blockCount)
+ if blockCount < 0:
+ white, black = black, white
+ sys.stdout.write(' '*50 + '\r')
+ sys.stdout.flush()
+ qr = qrText.replace('0', white).replace('1', black)
+ sys.stdout.write(qr)
+ sys.stdout.flush()
+
+def struct_friend_info(knownInfo):
+ member = copy.deepcopy(friendInfoTemplate)
+ for k, v in copy.deepcopy(knownInfo).items(): member[k] = v
+ return member
+
+def search_dict_list(l, key, value):
+ ''' Search a list of dict
+ * return dict with specific value & key '''
+ for i in l:
+ if i.get(key) == value:
+ return i
+
+def print_line(msg, oneLine = False):
+ if oneLine:
+ sys.stdout.write(' '*40 + '\r')
+ sys.stdout.flush()
+ else:
+ sys.stdout.write('\n')
+ sys.stdout.write(msg.encode(sys.stdin.encoding or 'utf8', 'replace'
+ ).decode(sys.stdin.encoding or 'utf8', 'replace'))
+ sys.stdout.flush()
+
+def test_connect(retryTime=5):
+ for i in range(retryTime):
+ try:
+ r = requests.get(config.BASE_URL)
+ return True
+ except:
+ if i == retryTime - 1:
+ logger.error(traceback.format_exc())
+ return False
+
+def contact_deep_copy(core, contact):
+ with core.storageClass.updateLock:
+ return copy.deepcopy(contact)
+
+def get_image_postfix(data):
+ data = data[:20]
+ if b'GIF' in data:
+ return 'gif'
+ elif b'PNG' in data:
+ return 'png'
+ elif b'JFIF' in data:
+ return 'jpg'
+ return ''
+
+def update_info_dict(oldInfoDict, newInfoDict):
+ ''' only normal values will be updated here
+ because newInfoDict is normal dict, so it's not necessary to consider templates
+ '''
+ for k, v in newInfoDict.items():
+ if any((isinstance(v, t) for t in (tuple, list, dict))):
+ pass # these values will be updated somewhere else
+ elif oldInfoDict.get(k) is None or v not in (None, '', '0', 0):
+ oldInfoDict[k] = v
\ No newline at end of file
diff --git a/nixpacks.toml b/nixpacks.toml
new file mode 100644
index 0000000..627d2e7
--- /dev/null
+++ b/nixpacks.toml
@@ -0,0 +1,7 @@
+[phases.setup]
+nixPkgs = ['python310']
+cmds = ['apt-get update','apt-get install -y --no-install-recommends ffmpeg espeak libavcodec-extra']
+[phases.install]
+cmds = ['python -m venv /opt/venv && . /opt/venv/bin/activate && pip install -r requirements.txt && pip install -r requirements-optional.txt']
+[start]
+cmd = "python ./app.py"
\ No newline at end of file
diff --git a/plugins/README.md b/plugins/README.md
new file mode 100644
index 0000000..2a44615
--- /dev/null
+++ b/plugins/README.md
@@ -0,0 +1,273 @@
+**Table of Content**
+
+- [插件化初衷](#插件化初衷)
+- [插件安装方法](#插件安装方法)
+- [插件化实现](#插件化实现)
+- [插件编写示例](#插件编写示例)
+- [插件设计建议](#插件设计建议)
+
+## 插件化初衷
+
+之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。多个功能同时存在时,无法调整功能的优先级顺序,功能配置项也非常混乱。
+
+此时插件化应声而出。
+
+**插件化**: 在保证主体功能是ChatGPT的前提下,我们推荐将主体功能外的功能利用插件的方式实现。
+
+- [x] 可根据功能需要,下载不同插件。
+- [x] 插件开发成本低,仅需了解插件触发事件,并按照插件定义接口编写插件。
+- [x] 插件化能够自由开关和调整优先级。
+- [x] 每个插件可在插件文件夹内维护独立的配置文件,方便代码的测试和调试,可以在独立的仓库开发插件。
+
+## 插件安装方法
+
+在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。
+
+- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。
+
+- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。
+
+ 安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。
+
+ - 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui
+
+ - 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git
+
+ 在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。
+
+安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。
+
+## 插件化实现
+
+插件化实现是在收到消息到发送回复的各个步骤之间插入触发事件实现的。
+
+### 消息处理过程
+
+在了解插件触发事件前,首先需要了解程序收到消息到发送回复的整个过程。
+
+插件化版本中,消息处理过程可以分为4个步骤:
+```
+ 1.收到消息 ---> 2.产生回复 ---> 3.包装回复 ---> 4.发送回复
+```
+
+以下是它们的默认处理逻辑(太长不看,可跳到[插件编写示例](#插件编写示例)):
+
+**注意以下包含的代码是`v1.1.0`中的片段,已过时,只可用于理解事件,最新的默认代码逻辑请参考[chat_channel](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/chat_channel.py)**
+
+#### 1. 收到消息
+
+负责接收用户消息,根据用户的配置,判断本条消息是否触发机器人。如果触发,则会判断该消息的类型(声音、文本、画图命令等),将消息包装成如下的`Context`交付给下一个步骤。
+
+```python
+ class ContextType (Enum):
+ TEXT = 1 # 文本消息
+ VOICE = 2 # 音频消息
+ IMAGE_CREATE = 3 # 创建图片命令
+ class Context:
+ def __init__(self, type : ContextType = None , content = None, kwargs = dict()):
+ self.type = type
+ self.content = content
+ self.kwargs = kwargs
+ def __getitem__(self, key):
+ return self.kwargs[key]
+```
+
+`Context`中除了存放消息类型和内容外,还存放了一些与会话相关的参数。
+
+例如,当收到用户私聊消息时,会存放以下的会话参数。
+
+```python
+ context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
+```
+
+- `isgroup`: `Context`是否是群聊消息。
+- `msg`: `itchat`中原始的消息对象。
+- `receiver`: 需要回复消息的对象ID。
+- `session_id`: 会话ID(一般是发送触发bot消息的用户ID,如果在群聊中并且`conf`里设置了`group_chat_in_one_session`,那么此处便是群聊ID)
+
+#### 2. 产生回复
+
+处理消息并产生回复。目前默认处理逻辑是根据`Context`的类型交付给对应的bot,并产生回复`Reply`。 如果本步骤没有产生任何回复,那么会跳过之后的所有步骤。
+
+```python
+ if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
+ reply = super().build_reply_content(context.content, context) #文字跟画图交付给chatgpt
+ elif context.type == ContextType.VOICE: # 声音先进行语音转文字后,修改Context类型为文字后,再交付给chatgpt
+ cmsg = context['msg']
+ cmsg.prepare()
+ file_name = context.content
+ reply = super().build_voice_to_text(file_name)
+ if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
+ context.content = reply.content # 语音转文字后,将文字内容作为新的context
+ context.type = ContextType.TEXT
+ reply = super().build_reply_content(context.content, context)
+ if reply.type == ReplyType.TEXT:
+ if conf().get('voice_reply_voice'):
+ reply = super().build_text_to_voice(reply.content)
+```
+
+回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。
+
+```python
+ class ReplyType(Enum):
+ TEXT = 1 # 文本
+ VOICE = 2 # 音频文件
+ IMAGE = 3 # 图片文件
+ IMAGE_URL = 4 # 图片URL
+
+ INFO = 9
+ ERROR = 10
+ class Reply:
+ def __init__(self, type : ReplyType = None , content = None):
+ self.type = type
+ self.content = content
+```
+
+#### 3. 装饰回复
+
+根据`Context`和回复`Reply`的类型,对回复的内容进行装饰。目前的装饰有以下两种:
+
+- `TEXT`文本回复:如果这次消息需要的回复是`VOICE`,进行文字转语音回复之后再次装饰。 否则根据是否在群聊中来决定是艾特接收方还是添加回复的前缀。
+
+- `INFO`或`ERROR`类型,会在消息前添加对应的系统提示字样。
+
+如下是默认逻辑的代码:
+
+```python
+ if reply.type == ReplyType.TEXT:
+ reply_text = reply.content
+ if context.get('desire_rtype') == ReplyType.VOICE:
+ reply = super().build_text_to_voice(reply.content)
+ return self._decorate_reply(context, reply)
+ if context['isgroup']:
+ reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
+ reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
+ else:
+ reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
+ reply.content = reply_text
+ elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
+ reply.content = str(reply.type)+":\n" + reply.content
+```
+
+#### 4. 发送回复
+
+根据`Reply`的类型,默认逻辑调用不同的发送函数发送回复给接收方`context["receiver"]`。
+
+### 插件触发事件
+
+主程序目前会在各个消息步骤间触发事件,监听相应事件的插件会按照优先级,顺序调用事件处理函数。
+
+目前支持三类触发事件:
+```
+1.收到消息
+---> `ON_HANDLE_CONTEXT`
+2.产生回复
+---> `ON_DECORATE_REPLY`
+3.装饰回复
+---> `ON_SEND_REPLY`
+4.发送回复
+```
+
+触发事件会产生事件的上下文`EventContext`,它包含了以下信息:
+
+`EventContext(Event事件类型, {'channel' : 消息channel, 'context': Context, 'reply': Reply})`
+
+插件处理函数可通过修改`EventContext`中的`context`和`reply`来实现功能。
+
+## 插件编写示例
+
+以`plugins/hello`为例,其中编写了一个简单的`Hello`插件。
+
+### 1. 创建插件
+
+在`plugins`目录下创建一个插件文件夹`hello`。然后,在该文件夹中创建``__init__.py``文件,在``__init__.py``中将其他编写的模块文件导入。在程序启动时,插件管理器会读取``__init__.py``的所有内容。
+
+```
+plugins/
+└── hello
+ ├── __init__.py
+ └── hello.py
+```
+
+``__init__.py``的内容:
+```
+from .hello import *
+```
+
+### 2. 编写插件类
+
+在`hello.py`文件中,创建插件类,它继承自`Plugin`。
+
+在类定义之前需要使用`@plugins.register`装饰器注册插件,并填写插件的相关信息,其中`desire_priority`表示插件默认的优先级,越大优先级越高。初次加载插件后可在`plugins/plugins.json`中修改插件优先级。
+
+并在`__init__`中绑定你编写的事件处理函数。
+
+`Hello`插件为事件`ON_HANDLE_CONTEXT`绑定了一个处理函数`on_handle_context`,它表示之后每次生成回复前,都会由`on_handle_context`先处理。
+
+PS: `ON_HANDLE_CONTEXT`是最常用的事件,如果要根据不同的消息来生成回复,就用它。
+
+```python
+@plugins.register(name="Hello", desc="A simple plugin that says hello", version="0.1", author="lanvent", desire_priority= -1)
+class Hello(Plugin):
+ def __init__(self):
+ super().__init__()
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ logger.info("[Hello] inited")
+```
+
+### 3. 编写事件处理函数
+
+#### 修改事件上下文
+
+事件处理函数接收一个`EventContext`对象`e_context`作为参数。`e_context`包含了事件相关信息,利用`e_context['key']`来访问这些信息。
+
+`EventContext(Event事件类型, {'channel' : 消息channel, 'context': Context, 'reply': Reply})`
+
+处理函数中通过修改`e_context`对象中的事件相关信息来实现所需功能,比如更改`e_context['reply']`中的内容可以修改回复。
+
+#### 决定是否交付给下个插件或默认逻辑
+
+在处理函数结束时,还需要设置`e_context`对象的`action`属性,它决定如何继续处理事件。目前有以下三种处理方式:
+
+- `EventAction.CONTINUE`: 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑。
+- `EventAction.BREAK`: 事件结束,不再给下个插件处理,交付给默认的处理逻辑。
+- `EventAction.BREAK_PASS`: 事件结束,不再给下个插件处理,跳过默认的处理逻辑。
+
+#### 示例处理函数
+
+`Hello`插件处理`Context`类型为`TEXT`的消息:
+
+- 如果内容是`Hello`,就将回复设置为`Hello+用户昵称`,并跳过之后的插件和默认逻辑。
+- 如果内容是`End`,就将`Context`的类型更改为`IMAGE_CREATE`,并让事件继续,如果最终交付到默认逻辑,会调用默认的画图Bot来画画。
+
+```python
+ def on_handle_context(self, e_context: EventContext):
+ if e_context['context'].type != ContextType.TEXT:
+ return
+ content = e_context['context'].content
+ if content == "Hello":
+ reply = Reply()
+ reply.type = ReplyType.TEXT
+ msg:ChatMessage = e_context['context']['msg']
+ if e_context['context']['isgroup']:
+ reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
+ else:
+ reply.content = f"Hello, {msg.from_user_nickname}"
+ e_context['reply'] = reply
+ e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+ if content == "End":
+ # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World"
+ e_context['context'].type = ContextType.IMAGE_CREATE
+ content = "The World"
+ e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
+```
+
+## 插件设计建议
+
+- 尽情将你想要的个性化功能设计为插件。
+- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。
+
+ 在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。
+
+- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。
+- 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。
diff --git a/plugins/__init__.py b/plugins/__init__.py
new file mode 100644
index 0000000..d515edb
--- /dev/null
+++ b/plugins/__init__.py
@@ -0,0 +1,9 @@
+from .event import *
+from .plugin import *
+from .plugin_manager import PluginManager
+
+instance = PluginManager()
+
+register = instance.register
+# load_plugins = instance.load_plugins
+# emit_event = instance.emit_event
diff --git a/plugins/banwords/.gitignore b/plugins/banwords/.gitignore
new file mode 100644
index 0000000..a6593bf
--- /dev/null
+++ b/plugins/banwords/.gitignore
@@ -0,0 +1 @@
+banwords.txt
\ No newline at end of file
diff --git a/plugins/banwords/README.md b/plugins/banwords/README.md
new file mode 100644
index 0000000..39517f6
--- /dev/null
+++ b/plugins/banwords/README.md
@@ -0,0 +1,27 @@
+
+## 插件描述
+
+简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。
+
+使用前将`config.json.template`复制为`config.json`,并自行配置。
+
+目前插件对消息的默认处理行为有如下两种:
+
+- `ignore` : 无视这条消息。
+- `replace` : 将消息中的敏感词替换成"*",并回复违规。
+
+```json
+ "action": "replace",
+ "reply_filter": true,
+ "reply_action": "ignore"
+```
+
+在以上配置项中:
+
+- `action`: 对用户消息的默认处理行为
+- `reply_filter`: 是否对ChatGPT的回复也进行敏感词过滤
+- `reply_action`: 如果开启了回复过滤,对回复的默认处理行为
+
+## 致谢
+
+搜索功能实现来自https://github.com/toolgood/ToolGood.Words
\ No newline at end of file
diff --git a/plugins/banwords/__init__.py b/plugins/banwords/__init__.py
new file mode 100644
index 0000000..503a563
--- /dev/null
+++ b/plugins/banwords/__init__.py
@@ -0,0 +1 @@
+from .banwords import *
diff --git a/plugins/banwords/banwords.py b/plugins/banwords/banwords.py
new file mode 100644
index 0000000..2a33a5a
--- /dev/null
+++ b/plugins/banwords/banwords.py
@@ -0,0 +1,100 @@
+# encoding:utf-8
+
+import json
+import os
+
+import plugins
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from plugins import *
+
+from .lib.WordsSearch import WordsSearch
+
+
+@plugins.register(
+ name="Banwords",
+ desire_priority=100,
+ hidden=True,
+ desc="判断消息中是否有敏感词、决定是否回复。",
+ version="1.0",
+ author="lanvent",
+)
+class Banwords(Plugin):
+ def __init__(self):
+ super().__init__()
+ try:
+ # load config
+ conf = super().load_config()
+ curdir = os.path.dirname(__file__)
+ if not conf:
+ # 配置不存在则写入默认配置
+ config_path = os.path.join(curdir, "config.json")
+ if not os.path.exists(config_path):
+ conf = {"action": "ignore"}
+ with open(config_path, "w") as f:
+ json.dump(conf, f, indent=4)
+
+ self.searchr = WordsSearch()
+ self.action = conf["action"]
+ banwords_path = os.path.join(curdir, "banwords.txt")
+ with open(banwords_path, "r", encoding="utf-8") as f:
+ words = []
+ for line in f:
+ word = line.strip()
+ if word:
+ words.append(word)
+ self.searchr.SetKeywords(words)
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ if conf.get("reply_filter", True):
+ self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply
+ self.reply_action = conf.get("reply_action", "ignore")
+ logger.info("[Banwords] inited")
+ except Exception as e:
+ logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
+ raise e
+
+ def on_handle_context(self, e_context: EventContext):
+ if e_context["context"].type not in [
+ ContextType.TEXT,
+ ContextType.IMAGE_CREATE,
+ ]:
+ return
+
+ content = e_context["context"].content
+ logger.debug("[Banwords] on_handle_context. content: %s" % content)
+ if self.action == "ignore":
+ f = self.searchr.FindFirst(content)
+ if f:
+ logger.info("[Banwords] %s in message" % f["Keyword"])
+ e_context.action = EventAction.BREAK_PASS
+ return
+ elif self.action == "replace":
+ if self.searchr.ContainsAny(content):
+ reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content))
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+
+ def on_decorate_reply(self, e_context: EventContext):
+ if e_context["reply"].type not in [ReplyType.TEXT]:
+ return
+
+ reply = e_context["reply"]
+ content = reply.content
+ if self.reply_action == "ignore":
+ f = self.searchr.FindFirst(content)
+ if f:
+ logger.info("[Banwords] %s in reply" % f["Keyword"])
+ e_context["reply"] = None
+ e_context.action = EventAction.BREAK_PASS
+ return
+ elif self.reply_action == "replace":
+ if self.searchr.ContainsAny(content):
+ reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content))
+ e_context["reply"] = reply
+ e_context.action = EventAction.CONTINUE
+ return
+
+ def get_help_text(self, **kwargs):
+ return "过滤消息中的敏感词。"
diff --git a/plugins/banwords/banwords.txt.template b/plugins/banwords/banwords.txt.template
new file mode 100644
index 0000000..9b2e8ed
--- /dev/null
+++ b/plugins/banwords/banwords.txt.template
@@ -0,0 +1,3 @@
+nipples
+pennis
+法轮功
\ No newline at end of file
diff --git a/plugins/banwords/config.json.template b/plugins/banwords/config.json.template
new file mode 100644
index 0000000..3117a83
--- /dev/null
+++ b/plugins/banwords/config.json.template
@@ -0,0 +1,5 @@
+{
+ "action": "replace",
+ "reply_filter": true,
+ "reply_action": "ignore"
+}
diff --git a/plugins/banwords/lib/WordsSearch.py b/plugins/banwords/lib/WordsSearch.py
new file mode 100644
index 0000000..d41d6e7
--- /dev/null
+++ b/plugins/banwords/lib/WordsSearch.py
@@ -0,0 +1,250 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# ToolGood.Words.WordsSearch.py
+# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words
+# Licensed under the Apache License 2.0
+# 更新日志
+# 2020.04.06 第一次提交
+# 2020.05.16 修改,支持大于0xffff的字符
+
+__all__ = ['WordsSearch']
+__author__ = 'Lin Zhijun'
+__date__ = '2020.05.16'
+
+class TrieNode():
+ def __init__(self):
+ self.Index = 0
+ self.Index = 0
+ self.Layer = 0
+ self.End = False
+ self.Char = ''
+ self.Results = []
+ self.m_values = {}
+ self.Failure = None
+ self.Parent = None
+
+ def Add(self,c):
+ if c in self.m_values :
+ return self.m_values[c]
+ node = TrieNode()
+ node.Parent = self
+ node.Char = c
+ self.m_values[c] = node
+ return node
+
+ def SetResults(self,index):
+ if (self.End == False):
+ self.End = True
+ self.Results.append(index)
+
+class TrieNode2():
+ def __init__(self):
+ self.End = False
+ self.Results = []
+ self.m_values = {}
+ self.minflag = 0xffff
+ self.maxflag = 0
+
+ def Add(self,c,node3):
+ if (self.minflag > c):
+ self.minflag = c
+ if (self.maxflag < c):
+ self.maxflag = c
+ self.m_values[c] = node3
+
+ def SetResults(self,index):
+ if (self.End == False) :
+ self.End = True
+ if (index in self.Results )==False :
+ self.Results.append(index)
+
+ def HasKey(self,c):
+ return c in self.m_values
+
+
+ def TryGetValue(self,c):
+ if (self.minflag <= c and self.maxflag >= c):
+ if c in self.m_values:
+ return self.m_values[c]
+ return None
+
+
+class WordsSearch():
+ def __init__(self):
+ self._first = {}
+ self._keywords = []
+ self._indexs=[]
+
+ def SetKeywords(self,keywords):
+ self._keywords = keywords
+ self._indexs=[]
+ for i in range(len(keywords)):
+ self._indexs.append(i)
+
+ root = TrieNode()
+ allNodeLayer={}
+
+ for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++)
+ p = self._keywords[i]
+ nd = root
+ for j in range(len(p)): # for (j = 0; j < p.length; j++)
+ nd = nd.Add(ord(p[j]))
+ if (nd.Layer == 0):
+ nd.Layer = j + 1
+ if nd.Layer in allNodeLayer:
+ allNodeLayer[nd.Layer].append(nd)
+ else:
+ allNodeLayer[nd.Layer]=[]
+ allNodeLayer[nd.Layer].append(nd)
+ nd.SetResults(i)
+
+
+ allNode = []
+ allNode.append(root)
+ for key in allNodeLayer.keys():
+ for nd in allNodeLayer[key]:
+ allNode.append(nd)
+ allNodeLayer=None
+
+ for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++)
+ if i==0 :
+ continue
+ nd=allNode[i]
+ nd.Index = i
+ r = nd.Parent.Failure
+ c = nd.Char
+ while (r != None and (c in r.m_values)==False):
+ r = r.Failure
+ if (r == None):
+ nd.Failure = root
+ else:
+ nd.Failure = r.m_values[c]
+ for key2 in nd.Failure.Results :
+ nd.SetResults(key2)
+ root.Failure = root
+
+ allNode2 = []
+ for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++)
+ allNode2.append( TrieNode2())
+
+ for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++)
+ oldNode = allNode[i]
+ newNode = allNode2[i]
+
+ for key in oldNode.m_values :
+ index = oldNode.m_values[key].Index
+ newNode.Add(key, allNode2[index])
+
+ for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++)
+ item = oldNode.Results[index]
+ newNode.SetResults(item)
+
+ oldNode=oldNode.Failure
+ while oldNode != root:
+ for key in oldNode.m_values :
+ if (newNode.HasKey(key) == False):
+ index = oldNode.m_values[key].Index
+ newNode.Add(key, allNode2[index])
+ for index in range(len(oldNode.Results)):
+ item = oldNode.Results[index]
+ newNode.SetResults(item)
+ oldNode=oldNode.Failure
+ allNode = None
+ root = None
+
+ # first = []
+ # for index in range(65535):# for (index = 0; index < 0xffff; index++)
+ # first.append(None)
+
+ # for key in allNode2[0].m_values :
+ # first[key] = allNode2[0].m_values[key]
+
+ self._first = allNode2[0]
+
+
+ def FindFirst(self,text):
+ ptr = None
+ for index in range(len(text)): # for (index = 0; index < text.length; index++)
+ t =ord(text[index]) # text.charCodeAt(index)
+ tn = None
+ if (ptr == None):
+ tn = self._first.TryGetValue(t)
+ else:
+ tn = ptr.TryGetValue(t)
+ if (tn==None):
+ tn = self._first.TryGetValue(t)
+
+
+ if (tn != None):
+ if (tn.End):
+ item = tn.Results[0]
+ keyword = self._keywords[item]
+ return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }
+ ptr = tn
+ return None
+
+ def FindAll(self,text):
+ ptr = None
+ list = []
+
+ for index in range(len(text)): # for (index = 0; index < text.length; index++)
+ t =ord(text[index]) # text.charCodeAt(index)
+ tn = None
+ if (ptr == None):
+ tn = self._first.TryGetValue(t)
+ else:
+ tn = ptr.TryGetValue(t)
+ if (tn==None):
+ tn = self._first.TryGetValue(t)
+
+
+ if (tn != None):
+ if (tn.End):
+ for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++)
+ item = tn.Results[j]
+ keyword = self._keywords[item]
+ list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] })
+ ptr = tn
+ return list
+
+
+ def ContainsAny(self,text):
+ ptr = None
+ for index in range(len(text)): # for (index = 0; index < text.length; index++)
+ t =ord(text[index]) # text.charCodeAt(index)
+ tn = None
+ if (ptr == None):
+ tn = self._first.TryGetValue(t)
+ else:
+ tn = ptr.TryGetValue(t)
+ if (tn==None):
+ tn = self._first.TryGetValue(t)
+
+ if (tn != None):
+ if (tn.End):
+ return True
+ ptr = tn
+ return False
+
+ def Replace(self,text, replaceChar = '*'):
+ result = list(text)
+
+ ptr = None
+ for i in range(len(text)): # for (i = 0; i < text.length; i++)
+ t =ord(text[i]) # text.charCodeAt(index)
+ tn = None
+ if (ptr == None):
+ tn = self._first.TryGetValue(t)
+ else:
+ tn = ptr.TryGetValue(t)
+ if (tn==None):
+ tn = self._first.TryGetValue(t)
+
+ if (tn != None):
+ if (tn.End):
+ maxLength = len( self._keywords[tn.Results[0]])
+ start = i + 1 - maxLength
+ for j in range(start,i+1): # for (j = start; j <= i; j++)
+ result[j] = replaceChar
+ ptr = tn
+ return ''.join(result)
\ No newline at end of file
diff --git a/plugins/bdunit/README.md b/plugins/bdunit/README.md
new file mode 100644
index 0000000..a2f2c78
--- /dev/null
+++ b/plugins/bdunit/README.md
@@ -0,0 +1,30 @@
+## 插件说明
+
+利用百度UNIT实现智能对话
+
+- 1.解决问题:chatgpt无法处理的指令,交给百度UNIT处理如:天气,日期时间,数学运算等
+- 2.如问时间:现在几点钟,今天几号
+- 3.如问天气:明天广州天气怎么样,这个周末深圳会不会下雨
+- 4.如问数学运算:23+45=多少,100-23=多少,35转化为二进制是多少?
+
+## 使用说明
+
+### 获取apikey
+
+在百度UNIT官网上自己创建应用,申请百度机器人,可以把预先训练好的模型导入到自己的应用中,
+
+see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087095fc10c8377aaf https://console.bce.baidu.com/ai平台申请
+
+### 配置文件
+
+将文件夹中`config.json.template`复制为`config.json`。
+
+在其中填写百度UNIT官网上获取应用的API Key和Secret Key
+
+``` json
+ {
+ "service_id": "s...", #"机器人ID"
+ "api_key": "",
+ "secret_key": ""
+ }
+```
\ No newline at end of file
diff --git a/plugins/bdunit/__init__.py b/plugins/bdunit/__init__.py
new file mode 100644
index 0000000..28f44b4
--- /dev/null
+++ b/plugins/bdunit/__init__.py
@@ -0,0 +1 @@
+from .bdunit import *
diff --git a/plugins/bdunit/bdunit.py b/plugins/bdunit/bdunit.py
new file mode 100644
index 0000000..33194e3
--- /dev/null
+++ b/plugins/bdunit/bdunit.py
@@ -0,0 +1,252 @@
+# encoding:utf-8
+import json
+import os
+import uuid
+from uuid import getnode as get_mac
+
+import requests
+
+import plugins
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from plugins import *
+
+"""利用百度UNIT实现智能对话
+ 如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
+"""
+
+
+@plugins.register(
+ name="BDunit",
+ desire_priority=0,
+ hidden=True,
+ desc="Baidu unit bot system",
+ version="0.1",
+ author="jackson",
+)
+class BDunit(Plugin):
+ def __init__(self):
+ super().__init__()
+ try:
+ conf = super().load_config()
+ if not conf:
+ raise Exception("config.json not found")
+ self.service_id = conf["service_id"]
+ self.api_key = conf["api_key"]
+ self.secret_key = conf["secret_key"]
+ self.access_token = self.get_token()
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ logger.info("[BDunit] inited")
+ except Exception as e:
+ logger.warn("[BDunit] init failed, ignore ")
+ raise e
+
+ def on_handle_context(self, e_context: EventContext):
+ if e_context["context"].type != ContextType.TEXT:
+ return
+
+ content = e_context["context"].content
+ logger.debug("[BDunit] on_handle_context. content: %s" % content)
+ parsed = self.getUnit2(content)
+ intent = self.getIntent(parsed)
+ if intent: # 找到意图
+ logger.debug("[BDunit] Baidu_AI Intent= %s", intent)
+ reply = Reply()
+ reply.type = ReplyType.TEXT
+ reply.content = self.getSay(parsed)
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+ else:
+ e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
+
+ def get_help_text(self, **kwargs):
+ help_text = "本插件会处理询问实时日期时间,天气,数学运算等问题,这些技能由您的百度智能对话UNIT决定\n"
+ return help_text
+
+ def get_token(self):
+ """获取访问百度UUNIT 的access_token
+ #param api_key: UNIT apk_key
+ #param secret_key: UNIT secret_key
+ Returns:
+ string: access_token
+ """
+ url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
+ payload = ""
+ headers = {"Content-Type": "application/json", "Accept": "application/json"}
+
+ response = requests.request("POST", url, headers=headers, data=payload)
+
+ # print(response.text)
+ return response.json()["access_token"]
+
+ def getUnit(self, query):
+ """
+ NLU 解析version 3.0
+ :param query: 用户的指令字符串
+ :returns: UNIT 解析结果。如果解析失败,返回 None
+ """
+
+ url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
+ request = {
+ "query": query,
+ "user_id": str(get_mac())[:32],
+ "terminal_id": "88888",
+ }
+ body = {
+ "log_id": str(uuid.uuid1()),
+ "version": "3.0",
+ "service_id": self.service_id,
+ "session_id": str(uuid.uuid1()),
+ "request": request,
+ }
+ try:
+ headers = {"Content-Type": "application/json"}
+ response = requests.post(url, json=body, headers=headers)
+ return json.loads(response.text)
+ except Exception:
+ return None
+
+ def getUnit2(self, query):
+ """
+ NLU 解析 version 2.0
+
+ :param query: 用户的指令字符串
+ :returns: UNIT 解析结果。如果解析失败,返回 None
+ """
+ url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
+ request = {"query": query, "user_id": str(get_mac())[:32]}
+ body = {
+ "log_id": str(uuid.uuid1()),
+ "version": "2.0",
+ "service_id": self.service_id,
+ "session_id": str(uuid.uuid1()),
+ "request": request,
+ }
+ try:
+ headers = {"Content-Type": "application/json"}
+ response = requests.post(url, json=body, headers=headers)
+ return json.loads(response.text)
+ except Exception:
+ return None
+
+ def getIntent(self, parsed):
+ """
+ 提取意图
+
+ :param parsed: UNIT 解析结果
+ :returns: 意图数组
+ """
+ if parsed and "result" in parsed and "response_list" in parsed["result"]:
+ try:
+ return parsed["result"]["response_list"][0]["schema"]["intent"]
+ except Exception as e:
+ logger.warning(e)
+ return ""
+ else:
+ return ""
+
+ def hasIntent(self, parsed, intent):
+ """
+ 判断是否包含某个意图
+
+ :param parsed: UNIT 解析结果
+ :param intent: 意图的名称
+ :returns: True: 包含; False: 不包含
+ """
+ if parsed and "result" in parsed and "response_list" in parsed["result"]:
+ response_list = parsed["result"]["response_list"]
+ for response in response_list:
+ if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
+ return True
+ return False
+ else:
+ return False
+
+ def getSlots(self, parsed, intent=""):
+ """
+ 提取某个意图的所有词槽
+
+ :param parsed: UNIT 解析结果
+ :param intent: 意图的名称
+ :returns: 词槽列表。你可以通过 name 属性筛选词槽,
+ 再通过 normalized_word 属性取出相应的值
+ """
+ if parsed and "result" in parsed and "response_list" in parsed["result"]:
+ response_list = parsed["result"]["response_list"]
+ if intent == "":
+ try:
+ return parsed["result"]["response_list"][0]["schema"]["slots"]
+ except Exception as e:
+ logger.warning(e)
+ return []
+ for response in response_list:
+ if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
+ return response["schema"]["slots"]
+ return []
+ else:
+ return []
+
+ def getSlotWords(self, parsed, intent, name):
+ """
+ 找出命中某个词槽的内容
+
+ :param parsed: UNIT 解析结果
+ :param intent: 意图的名称
+ :param name: 词槽名
+ :returns: 命中该词槽的值的列表。
+ """
+ slots = self.getSlots(parsed, intent)
+ words = []
+ for slot in slots:
+ if slot["name"] == name:
+ words.append(slot["normalized_word"])
+ return words
+
+ def getSayByConfidence(self, parsed):
+ """
+ 提取 UNIT 置信度最高的回复文本
+
+ :param parsed: UNIT 解析结果
+ :returns: UNIT 的回复文本
+ """
+ if parsed and "result" in parsed and "response_list" in parsed["result"]:
+ response_list = parsed["result"]["response_list"]
+ answer = {}
+ for response in response_list:
+ if (
+ "schema" in response
+ and "intent_confidence" in response["schema"]
+ and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
+ ):
+ answer = response
+ return answer["action_list"][0]["say"]
+ else:
+ return ""
+
+ def getSay(self, parsed, intent=""):
+ """
+ 提取 UNIT 的回复文本
+
+ :param parsed: UNIT 解析结果
+ :param intent: 意图的名称
+ :returns: UNIT 的回复文本
+ """
+ if parsed and "result" in parsed and "response_list" in parsed["result"]:
+ response_list = parsed["result"]["response_list"]
+ if intent == "":
+ try:
+ return response_list[0]["action_list"][0]["say"]
+ except Exception as e:
+ logger.warning(e)
+ return ""
+ for response in response_list:
+ if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
+ try:
+ return response["action_list"][0]["say"]
+ except Exception as e:
+ logger.warning(e)
+ return ""
+ return ""
+ else:
+ return ""
diff --git a/plugins/bdunit/config.json.template b/plugins/bdunit/config.json.template
new file mode 100644
index 0000000..c3bad56
--- /dev/null
+++ b/plugins/bdunit/config.json.template
@@ -0,0 +1,5 @@
+{
+ "service_id": "s...",
+ "api_key": "",
+ "secret_key": ""
+}
diff --git a/plugins/config.json.template b/plugins/config.json.template
new file mode 100644
index 0000000..95a59bc
--- /dev/null
+++ b/plugins/config.json.template
@@ -0,0 +1,44 @@
+{
+ "godcmd": {
+ "password": "",
+ "admin_users": []
+ },
+ "banwords": {
+ "action": "replace",
+ "reply_filter": true,
+ "reply_action": "ignore"
+ },
+ "tool": {
+ "tools": [
+ "python",
+ "url-get",
+ "terminal",
+ "meteo-weather"
+ ],
+ "kwargs": {
+ "top_k_results": 2,
+ "no_default": false,
+ "model_name": "gpt-3.5-turbo"
+ }
+ },
+ "linkai": {
+ "group_app_map": {
+ "测试群1": "default",
+ "测试群2": "Kv2fXJcH"
+ },
+ "midjourney": {
+ "enabled": true,
+ "auto_translate": true,
+ "img_proxy": true,
+ "max_tasks": 3,
+ "max_tasks_per_user": 1,
+ "use_image_create_prefix": true
+ },
+ "summary": {
+ "enabled": true,
+ "group_enabled": true,
+ "max_file_size": 5000,
+ "type": ["FILE", "SHARING"]
+ }
+ }
+}
diff --git a/plugins/dungeon/README.md b/plugins/dungeon/README.md
new file mode 100644
index 0000000..2c2e8cd
--- /dev/null
+++ b/plugins/dungeon/README.md
@@ -0,0 +1,4 @@
+玩地牢游戏的聊天插件,触发方法如下:
+
+- `$开始冒险 <背景故事>` - 以<背景故事>开始一个地牢游戏,不填写会使用默认背景故事。之后聊天中你的所有消息会帮助ai完善这个故事。
+- `$停止冒险` - 停止一个地牢游戏,回归正常的ai。
diff --git a/plugins/dungeon/__init__.py b/plugins/dungeon/__init__.py
new file mode 100644
index 0000000..6b10443
--- /dev/null
+++ b/plugins/dungeon/__init__.py
@@ -0,0 +1 @@
+from .dungeon import *
diff --git a/plugins/dungeon/dungeon.py b/plugins/dungeon/dungeon.py
new file mode 100644
index 0000000..dce62cd
--- /dev/null
+++ b/plugins/dungeon/dungeon.py
@@ -0,0 +1,106 @@
+# encoding:utf-8
+
+import plugins
+from bridge.bridge import Bridge
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common import const
+from common.expired_dict import ExpiredDict
+from common.log import logger
+from config import conf
+from plugins import *
+
+
+# https://github.com/bupticybee/ChineseAiDungeonChatGPT
+class StoryTeller:
+ def __init__(self, bot, sessionid, story):
+ self.bot = bot
+ self.sessionid = sessionid
+ bot.sessions.clear_session(sessionid)
+ self.first_interact = True
+ self.story = story
+
+ def reset(self):
+ self.bot.sessions.clear_session(self.sessionid)
+ self.first_interact = True
+
+ def action(self, user_action):
+ if user_action[-1] != "。":
+ user_action = user_action + "。"
+ if self.first_interact:
+ prompt = (
+ """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。
+ 开头是,"""
+ + self.story
+ + " "
+ + user_action
+ )
+ self.first_interact = False
+ else:
+ prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action
+ return prompt
+
+
+@plugins.register(
+ name="Dungeon",
+ desire_priority=0,
+ namecn="文字冒险",
+ desc="A plugin to play dungeon game",
+ version="1.0",
+ author="lanvent",
+)
+class Dungeon(Plugin):
+ def __init__(self):
+ super().__init__()
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ logger.info("[Dungeon] inited")
+ # 目前没有设计session过期事件,这里先暂时使用过期字典
+ if conf().get("expires_in_seconds"):
+ self.games = ExpiredDict(conf().get("expires_in_seconds"))
+ else:
+ self.games = dict()
+
+ def on_handle_context(self, e_context: EventContext):
+ if e_context["context"].type != ContextType.TEXT:
+ return
+ bottype = Bridge().get_bot_type("chat")
+ if bottype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
+ return
+ bot = Bridge().get_bot("chat")
+ content = e_context["context"].content[:]
+ clist = e_context["context"].content.split(maxsplit=1)
+ sessionid = e_context["context"]["session_id"]
+ logger.debug("[Dungeon] on_handle_context. content: %s" % clist)
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ if clist[0] == f"{trigger_prefix}停止冒险":
+ if sessionid in self.games:
+ self.games[sessionid].reset()
+ del self.games[sessionid]
+ reply = Reply(ReplyType.INFO, "冒险结束!")
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games:
+ if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险":
+ if len(clist) > 1:
+ story = clist[1]
+ else:
+ story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
+ self.games[sessionid] = StoryTeller(bot, sessionid, story)
+ reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+ else:
+ prompt = self.games[sessionid].action(content)
+ e_context["context"].type = ContextType.TEXT
+ e_context["context"].content = prompt
+ e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑
+
+ def get_help_text(self, **kwargs):
+ help_text = "可以和机器人一起玩文字冒险游戏。\n"
+ if kwargs.get("verbose") != True:
+ return help_text
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n"
+ if kwargs.get("verbose") == True:
+ help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
+ return help_text
diff --git a/plugins/event.py b/plugins/event.py
new file mode 100644
index 0000000..719e6fc
--- /dev/null
+++ b/plugins/event.py
@@ -0,0 +1,55 @@
+# encoding:utf-8
+
+from enum import Enum
+
+
+class Event(Enum):
+ ON_RECEIVE_MESSAGE = 1 # 收到消息
+ """
+ e_context = { "channel": 消息channel, "context" : 本次消息的context}
+ """
+
+ ON_HANDLE_CONTEXT = 2 # 处理消息前
+ """
+ e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 }
+ """
+
+ ON_DECORATE_REPLY = 3 # 得到回复后准备装饰
+ """
+ e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
+ """
+
+ ON_SEND_REPLY = 4 # 发送回复前
+ """
+ e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
+ """
+
+ # AFTER_SEND_REPLY = 5 # 发送回复后
+
+
+class EventAction(Enum):
+ CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
+ BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
+ BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
+
+
+class EventContext:
+ def __init__(self, event, econtext=dict()):
+ self.event = event
+ self.econtext = econtext
+ self.action = EventAction.CONTINUE
+
+ def __getitem__(self, key):
+ return self.econtext[key]
+
+ def __setitem__(self, key, value):
+ self.econtext[key] = value
+
+ def __delitem__(self, key):
+ del self.econtext[key]
+
+ def is_pass(self):
+ return self.action == EventAction.BREAK_PASS
+
+ def is_break(self):
+ return self.action == EventAction.BREAK or self.action == EventAction.BREAK_PASS
diff --git a/plugins/finish/__init__.py b/plugins/finish/__init__.py
new file mode 100644
index 0000000..8c1cfd9
--- /dev/null
+++ b/plugins/finish/__init__.py
@@ -0,0 +1 @@
+from .finish import *
diff --git a/plugins/finish/finish.py b/plugins/finish/finish.py
new file mode 100644
index 0000000..a3c87ea
--- /dev/null
+++ b/plugins/finish/finish.py
@@ -0,0 +1,40 @@
+# encoding:utf-8
+
+import plugins
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+from plugins import *
+
+
+@plugins.register(
+ name="Finish",
+ desire_priority=-999,
+ hidden=True,
+ desc="A plugin that check unknown command",
+ version="1.0",
+ author="js00000",
+)
+class Finish(Plugin):
+ def __init__(self):
+ super().__init__()
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ logger.info("[Finish] inited")
+
+ def on_handle_context(self, e_context: EventContext):
+ if e_context["context"].type != ContextType.TEXT:
+ return
+
+ content = e_context["context"].content
+ logger.debug("[Finish] on_handle_context. content: %s" % content)
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ if content.startswith(trigger_prefix):
+ reply = Reply()
+ reply.type = ReplyType.ERROR
+ reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n"
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+
+ def get_help_text(self, **kwargs):
+ return ""
diff --git a/plugins/godcmd/README.md b/plugins/godcmd/README.md
new file mode 100644
index 0000000..3b7ce1d
--- /dev/null
+++ b/plugins/godcmd/README.md
@@ -0,0 +1,18 @@
+## 插件说明
+
+指令插件
+
+## 插件使用
+
+将`config.json.template`复制为`config.json`,并修改其中`password`的值为口令。
+
+如果没有设置命令,在命令行日志中会打印出本次的临时口令,请注意观察,打印格式如下。
+
+```
+[INFO][2023-04-06 23:53:47][godcmd.py:165] - [Godcmd] 因未设置口令,本次的临时口令为0971。
+```
+
+在私聊中可使用`#auth`指令,输入口令进行管理员认证。更多详细指令请输入`#help`查看帮助文档:
+
+`#auth <口令>` - 管理员认证,仅可在私聊时认证。
+`#help` - 输出帮助文档,**是否是管理员**和是否是在群聊中会影响帮助文档的输出内容。
diff --git a/plugins/godcmd/__init__.py b/plugins/godcmd/__init__.py
new file mode 100644
index 0000000..0e26552
--- /dev/null
+++ b/plugins/godcmd/__init__.py
@@ -0,0 +1 @@
+from .godcmd import *
diff --git a/plugins/godcmd/config.json.template b/plugins/godcmd/config.json.template
new file mode 100644
index 0000000..ed021e0
--- /dev/null
+++ b/plugins/godcmd/config.json.template
@@ -0,0 +1,4 @@
+{
+ "password": "",
+ "admin_users": []
+}
diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py
new file mode 100644
index 0000000..a965a68
--- /dev/null
+++ b/plugins/godcmd/godcmd.py
@@ -0,0 +1,485 @@
+# encoding:utf-8
+
+import json
+import os
+import random
+import string
+import logging
+from typing import Tuple
+
+import bridge.bridge
+import plugins
+from bridge.bridge import Bridge
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common import const
+from config import conf, load_config, global_config
+from plugins import *
+
+# 定义指令集
+COMMANDS = {
+ "help": {
+ "alias": ["help", "帮助"],
+ "desc": "回复此帮助",
+ },
+ "helpp": {
+ "alias": ["help", "帮助"], # 与help指令共用别名,根据参数数量区分
+ "args": ["插件名"],
+ "desc": "回复指定插件的详细帮助",
+ },
+ "auth": {
+ "alias": ["auth", "认证"],
+ "args": ["口令"],
+ "desc": "管理员认证",
+ },
+ "model": {
+ "alias": ["model", "模型"],
+ "desc": "查看和设置全局模型",
+ },
+ "set_openai_api_key": {
+ "alias": ["set_openai_api_key"],
+ "args": ["api_key"],
+ "desc": "设置你的OpenAI私有api_key",
+ },
+ "reset_openai_api_key": {
+ "alias": ["reset_openai_api_key"],
+ "desc": "重置为默认的api_key",
+ },
+ "set_gpt_model": {
+ "alias": ["set_gpt_model"],
+ "desc": "设置你的私有模型",
+ },
+ "reset_gpt_model": {
+ "alias": ["reset_gpt_model"],
+ "desc": "重置你的私有模型",
+ },
+ "gpt_model": {
+ "alias": ["gpt_model"],
+ "desc": "查询你使用的模型",
+ },
+ "id": {
+ "alias": ["id", "用户"],
+ "desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员
+ },
+ "reset": {
+ "alias": ["reset", "重置会话"],
+ "desc": "重置会话",
+ },
+}
+
+ADMIN_COMMANDS = {
+ "resume": {
+ "alias": ["resume", "恢复服务"],
+ "desc": "恢复服务",
+ },
+ "stop": {
+ "alias": ["stop", "暂停服务"],
+ "desc": "暂停服务",
+ },
+ "reconf": {
+ "alias": ["reconf", "重载配置"],
+ "desc": "重载配置(不包含插件配置)",
+ },
+ "resetall": {
+ "alias": ["resetall", "重置所有会话"],
+ "desc": "重置所有会话",
+ },
+ "scanp": {
+ "alias": ["scanp", "扫描插件"],
+ "desc": "扫描插件目录是否有新插件",
+ },
+ "plist": {
+ "alias": ["plist", "插件"],
+ "desc": "打印当前插件列表",
+ },
+ "setpri": {
+ "alias": ["setpri", "设置插件优先级"],
+ "args": ["插件名", "优先级"],
+ "desc": "设置指定插件的优先级,越大越优先",
+ },
+ "reloadp": {
+ "alias": ["reloadp", "重载插件"],
+ "args": ["插件名"],
+ "desc": "重载指定插件配置",
+ },
+ "enablep": {
+ "alias": ["enablep", "启用插件"],
+ "args": ["插件名"],
+ "desc": "启用指定插件",
+ },
+ "disablep": {
+ "alias": ["disablep", "禁用插件"],
+ "args": ["插件名"],
+ "desc": "禁用指定插件",
+ },
+ "installp": {
+ "alias": ["installp", "安装插件"],
+ "args": ["仓库地址或插件名"],
+ "desc": "安装指定插件",
+ },
+ "uninstallp": {
+ "alias": ["uninstallp", "卸载插件"],
+ "args": ["插件名"],
+ "desc": "卸载指定插件",
+ },
+ "updatep": {
+ "alias": ["updatep", "更新插件"],
+ "args": ["插件名"],
+ "desc": "更新指定插件",
+ },
+ "debug": {
+ "alias": ["debug", "调试模式", "DEBUG"],
+ "desc": "开启机器调试日志",
+ },
+}
+
+
+# 定义帮助函数
+def get_help_text(isadmin, isgroup):
+ help_text = "通用指令\n"
+ for cmd, info in COMMANDS.items():
+ if cmd in ["auth", "set_openai_api_key", "reset_openai_api_key", "set_gpt_model", "reset_gpt_model", "gpt_model"]: # 不显示帮助指令
+ continue
+ if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]:
+ continue
+ alias = ["#" + a for a in info["alias"][:1]]
+ help_text += f"{','.join(alias)} "
+ if "args" in info:
+ args = [a for a in info["args"]]
+ help_text += f"{' '.join(args)}"
+ help_text += f": {info['desc']}\n"
+
+ # 插件指令
+ plugins = PluginManager().list_plugins()
+ help_text += "\n可用插件"
+ for plugin in plugins:
+ if plugins[plugin].enabled and not plugins[plugin].hidden:
+ namecn = plugins[plugin].namecn
+ help_text += "\n%s:" % namecn
+ help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
+
+ if ADMIN_COMMANDS and isadmin:
+ help_text += "\n\n管理员指令:\n"
+ for cmd, info in ADMIN_COMMANDS.items():
+ alias = ["#" + a for a in info["alias"][:1]]
+ help_text += f"{','.join(alias)} "
+ if "args" in info:
+ args = [a for a in info["args"]]
+ help_text += f"{' '.join(args)}"
+ help_text += f": {info['desc']}\n"
+ return help_text
+
+
+@plugins.register(
+ name="Godcmd",
+ desire_priority=999,
+ hidden=True,
+ desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证",
+ version="1.0",
+ author="lanvent",
+)
+class Godcmd(Plugin):
+ def __init__(self):
+ super().__init__()
+
+ config_path = os.path.join(os.path.dirname(__file__), "config.json")
+ gconf = super().load_config()
+ if not gconf:
+ if not os.path.exists(config_path):
+ gconf = {"password": "", "admin_users": []}
+ with open(config_path, "w") as f:
+ json.dump(gconf, f, indent=4)
+ if gconf["password"] == "":
+ self.temp_password = "".join(random.sample(string.digits, 4))
+ logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。" % self.temp_password)
+ else:
+ self.temp_password = None
+ custom_commands = conf().get("clear_memory_commands", [])
+ for custom_command in custom_commands:
+ if custom_command and custom_command.startswith("#"):
+ custom_command = custom_command[1:]
+ if custom_command and custom_command not in COMMANDS["reset"]["alias"]:
+ COMMANDS["reset"]["alias"].append(custom_command)
+
+ self.password = gconf["password"]
+ self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
+ global_config["admin_users"] = self.admin_users
+ self.isrunning = True # 机器人是否运行中
+
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ logger.info("[Godcmd] inited")
+
+ def on_handle_context(self, e_context: EventContext):
+ context_type = e_context["context"].type
+ if context_type != ContextType.TEXT:
+ if not self.isrunning:
+ e_context.action = EventAction.BREAK_PASS
+ return
+
+ content = e_context["context"].content
+ logger.debug("[Godcmd] on_handle_context. content: %s" % content)
+ if content.startswith("#"):
+ if len(content) == 1:
+ reply = Reply()
+ reply.type = ReplyType.ERROR
+ reply.content = f"空指令,输入#help查看指令列表\n"
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+ # msg = e_context['context']['msg']
+ channel = e_context["channel"]
+ user = e_context["context"]["receiver"]
+ session_id = e_context["context"]["session_id"]
+ isgroup = e_context["context"].get("isgroup", False)
+ bottype = Bridge().get_bot_type("chat")
+ bot = Bridge().get_bot("chat")
+ # 将命令和参数分割
+ command_parts = content[1:].strip().split()
+ cmd = command_parts[0]
+ args = command_parts[1:]
+ isadmin = False
+ if user in self.admin_users:
+ isadmin = True
+ ok = False
+ result = "string"
+ if any(cmd in info["alias"] for info in COMMANDS.values()):
+ cmd = next(c for c, info in COMMANDS.items() if cmd in info["alias"])
+ if cmd == "auth":
+ ok, result = self.authenticate(user, args, isadmin, isgroup)
+ elif cmd == "help" or cmd == "helpp":
+ if len(args) == 0:
+ ok, result = True, get_help_text(isadmin, isgroup)
+ else:
+ # This can replace the helpp command
+ plugins = PluginManager().list_plugins()
+ query_name = args[0].upper()
+ # search name and namecn
+ for name, plugincls in plugins.items():
+ if not plugincls.enabled:
+ continue
+ if query_name == name or query_name == plugincls.namecn:
+ ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
+ break
+ if not ok:
+ result = "插件不存在或未启用"
+ elif cmd == "model":
+ if not isadmin and not self.is_admin_in_group(e_context["context"]):
+ ok, result = False, "需要管理员权限执行"
+ elif len(args) == 0:
+ model = conf().get("model") or const.GPT35
+ ok, result = True, "当前模型为: " + str(model)
+ elif len(args) == 1:
+ if args[0] not in const.MODEL_LIST:
+ ok, result = False, "模型名称不存在"
+ else:
+ conf()["model"] = self.model_mapping(args[0])
+ Bridge().reset_bot()
+ model = conf().get("model") or const.GPT35
+ ok, result = True, "模型设置为: " + str(model)
+ elif cmd == "id":
+ ok, result = True, user
+ elif cmd == "set_openai_api_key":
+ if len(args) == 1:
+ user_data = conf().get_user_data(user)
+ user_data["openai_api_key"] = args[0]
+ ok, result = True, "你的OpenAI私有api_key已设置为" + args[0]
+ else:
+ ok, result = False, "请提供一个api_key"
+ elif cmd == "reset_openai_api_key":
+ try:
+ user_data = conf().get_user_data(user)
+ user_data.pop("openai_api_key")
+ ok, result = True, "你的OpenAI私有api_key已清除"
+ except Exception as e:
+ ok, result = False, "你没有设置私有api_key"
+ elif cmd == "set_gpt_model":
+ if len(args) == 1:
+ user_data = conf().get_user_data(user)
+ user_data["gpt_model"] = args[0]
+ ok, result = True, "你的GPT模型已设置为" + args[0]
+ else:
+ ok, result = False, "请提供一个GPT模型"
+ elif cmd == "gpt_model":
+ user_data = conf().get_user_data(user)
+ model = conf().get("model")
+ if "gpt_model" in user_data:
+ model = user_data["gpt_model"]
+ ok, result = True, "你的GPT模型为" + str(model)
+ elif cmd == "reset_gpt_model":
+ try:
+ user_data = conf().get_user_data(user)
+ user_data.pop("gpt_model")
+ ok, result = True, "你的GPT模型已重置"
+ except Exception as e:
+ ok, result = False, "你没有设置私有GPT模型"
+ elif cmd == "reset":
+ if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
+ bot.sessions.clear_session(session_id)
+ if Bridge().chat_bots.get(bottype):
+ Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)
+ channel.cancel_session(session_id)
+ ok, result = True, "会话已重置"
+ else:
+ ok, result = False, "当前对话机器人不支持重置会话"
+ logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
+ elif any(cmd in info["alias"] for info in ADMIN_COMMANDS.values()):
+ if isadmin:
+ if isgroup:
+ ok, result = False, "群聊不可执行管理员指令"
+ else:
+ cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"])
+ if cmd == "stop":
+ self.isrunning = False
+ ok, result = True, "服务已暂停"
+ elif cmd == "resume":
+ self.isrunning = True
+ ok, result = True, "服务已恢复"
+ elif cmd == "reconf":
+ load_config()
+ ok, result = True, "配置已重载"
+ elif cmd == "resetall":
+ if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
+ const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
+ channel.cancel_all_session()
+ bot.sessions.clear_all_session()
+ ok, result = True, "重置所有会话成功"
+ else:
+ ok, result = False, "当前对话机器人不支持重置会话"
+ elif cmd == "debug":
+ if logger.getEffectiveLevel() == logging.DEBUG: # 判断当前日志模式是否DEBUG
+ logger.setLevel(logging.INFO)
+ ok, result = True, "DEBUG模式已关闭"
+ else:
+ logger.setLevel(logging.DEBUG)
+ ok, result = True, "DEBUG模式已开启"
+ elif cmd == "plist":
+ plugins = PluginManager().list_plugins()
+ ok = True
+ result = "插件列表:\n"
+ for name, plugincls in plugins.items():
+ result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - "
+ if plugincls.enabled:
+ result += "已启用\n"
+ else:
+ result += "未启用\n"
+ elif cmd == "scanp":
+ new_plugins = PluginManager().scan_plugins()
+ ok, result = True, "插件扫描完成"
+ PluginManager().activate_plugins()
+ if len(new_plugins) > 0:
+ result += "\n发现新插件:\n"
+ result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
+ else:
+ result += ", 未发现新插件"
+ elif cmd == "setpri":
+ if len(args) != 2:
+ ok, result = False, "请提供插件名和优先级"
+ else:
+ ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
+ if ok:
+ result = "插件" + args[0] + "优先级已设置为" + args[1]
+ else:
+ result = "插件不存在"
+ elif cmd == "reloadp":
+ if len(args) != 1:
+ ok, result = False, "请提供插件名"
+ else:
+ ok = PluginManager().reload_plugin(args[0])
+ if ok:
+ result = "插件配置已重载"
+ else:
+ result = "插件不存在"
+ elif cmd == "enablep":
+ if len(args) != 1:
+ ok, result = False, "请提供插件名"
+ else:
+ ok, result = PluginManager().enable_plugin(args[0])
+ elif cmd == "disablep":
+ if len(args) != 1:
+ ok, result = False, "请提供插件名"
+ else:
+ ok = PluginManager().disable_plugin(args[0])
+ if ok:
+ result = "插件已禁用"
+ else:
+ result = "插件不存在"
+ elif cmd == "installp":
+ if len(args) != 1:
+ ok, result = False, "请提供插件名或.git结尾的仓库地址"
+ else:
+ ok, result = PluginManager().install_plugin(args[0])
+ elif cmd == "uninstallp":
+ if len(args) != 1:
+ ok, result = False, "请提供插件名"
+ else:
+ ok, result = PluginManager().uninstall_plugin(args[0])
+ elif cmd == "updatep":
+ if len(args) != 1:
+ ok, result = False, "请提供插件名"
+ else:
+ ok, result = PluginManager().update_plugin(args[0])
+ logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user))
+ else:
+ ok, result = False, "需要管理员权限才能执行该指令"
+ else:
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交
+ return
+ ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n"
+
+ reply = Reply()
+ if ok:
+ reply.type = ReplyType.INFO
+ else:
+ reply.type = ReplyType.ERROR
+ reply.content = result
+ e_context["reply"] = reply
+
+ e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+ elif not self.isrunning:
+ e_context.action = EventAction.BREAK_PASS
+
+ def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool, str]:
+ if isgroup:
+ return False, "请勿在群聊中认证"
+
+ if isadmin:
+ return False, "管理员账号无需认证"
+
+ if len(args) != 1:
+ return False, "请提供口令"
+
+ password = args[0]
+ if password == self.password:
+ self.admin_users.append(userid)
+ global_config["admin_users"].append(userid)
+ return True, "认证成功"
+ elif password == self.temp_password:
+ self.admin_users.append(userid)
+ global_config["admin_users"].append(userid)
+ return True, "认证成功,请尽快设置口令"
+ else:
+ return False, "认证失败"
+
+ def get_help_text(self, isadmin=False, isgroup=False, **kwargs):
+ return get_help_text(isadmin, isgroup)
+
+
+ def is_admin_in_group(self, context):
+ if context["isgroup"]:
+ return context.kwargs.get("msg").actual_user_id in global_config["admin_users"]
+ return False
+
+
+ def model_mapping(self, model) -> str:
+ if model == "gpt-4-turbo":
+ return const.GPT4_TURBO_PREVIEW
+ return model
+
+ def reload(self):
+ gconf = plugin_config[self.name]
+ if gconf:
+ if gconf.get("password"):
+ self.password = gconf["password"]
+ if gconf.get("admin_users"):
+ self.admin_users = gconf["admin_users"]
diff --git a/plugins/hello/__init__.py b/plugins/hello/__init__.py
new file mode 100644
index 0000000..d9b15a1
--- /dev/null
+++ b/plugins/hello/__init__.py
@@ -0,0 +1 @@
+from .hello import *
diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py
new file mode 100644
index 0000000..e86c609
--- /dev/null
+++ b/plugins/hello/hello.py
@@ -0,0 +1,98 @@
+# encoding:utf-8
+
+import plugins
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from channel.chat_message import ChatMessage
+from common.log import logger
+from plugins import *
+from config import conf
+
+
+@plugins.register(
+ name="Hello",
+ desire_priority=-1,
+ hidden=True,
+ desc="A simple plugin that says hello",
+ version="0.1",
+ author="lanvent",
+)
+class Hello(Plugin):
+ def __init__(self):
+ super().__init__()
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ logger.info("[Hello] inited")
+ self.config = super().load_config()
+
+ def on_handle_context(self, e_context: EventContext):
+ if e_context["context"].type not in [
+ ContextType.TEXT,
+ ContextType.JOIN_GROUP,
+ ContextType.PATPAT,
+ ContextType.EXIT_GROUP
+ ]:
+ return
+ if e_context["context"].type == ContextType.JOIN_GROUP:
+ if "group_welcome_msg" in conf():
+ reply = Reply()
+ reply.type = ReplyType.TEXT
+ reply.content = conf().get("group_welcome_msg", "")
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+ return
+ e_context["context"].type = ContextType.TEXT
+ msg: ChatMessage = e_context["context"]["msg"]
+ e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
+ e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
+ if not self.config or not self.config.get("use_character_desc"):
+ e_context["context"]["generate_breaked_by"] = EventAction.BREAK
+ return
+
+ if e_context["context"].type == ContextType.EXIT_GROUP:
+ if conf().get("group_chat_exit_group"):
+ e_context["context"].type = ContextType.TEXT
+ msg: ChatMessage = e_context["context"]["msg"]
+ e_context["context"].content = f'请你随机使用一种风格跟其他群用户说他违反规则"{msg.actual_user_nickname}"退出群聊。'
+ e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
+ return
+ e_context.action = EventAction.BREAK
+ return
+
+ if e_context["context"].type == ContextType.PATPAT:
+ e_context["context"].type = ContextType.TEXT
+ msg: ChatMessage = e_context["context"]["msg"]
+ e_context["context"].content = f"请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。"
+ e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
+ if not self.config or not self.config.get("use_character_desc"):
+ e_context["context"]["generate_breaked_by"] = EventAction.BREAK
+ return
+
+ content = e_context["context"].content
+ logger.debug("[Hello] on_handle_context. content: %s" % content)
+ if content == "Hello":
+ reply = Reply()
+ reply.type = ReplyType.TEXT
+ msg: ChatMessage = e_context["context"]["msg"]
+ if e_context["context"]["isgroup"]:
+ reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
+ else:
+ reply.content = f"Hello, {msg.from_user_nickname}"
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+
+ if content == "Hi":
+ reply = Reply()
+ reply.type = ReplyType.TEXT
+ reply.content = "Hi"
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply
+
+ if content == "End":
+ # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World"
+ e_context["context"].type = ContextType.IMAGE_CREATE
+ content = "The World"
+ e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
+
+ def get_help_text(self, **kwargs):
+ help_text = "输入Hello,我会回复你的名字\n输入End,我会回复你世界的图片\n"
+ return help_text
diff --git a/plugins/keyword/README.md b/plugins/keyword/README.md
new file mode 100644
index 0000000..4678f68
--- /dev/null
+++ b/plugins/keyword/README.md
@@ -0,0 +1,13 @@
+# 目的
+关键字匹配并回复
+
+# 试用场景
+目前是在微信公众号下面使用过。
+
+# 使用步骤
+1. 复制 `config.json.template` 为 `config.json`
+2. 在关键字 `keyword` 新增需要关键字匹配的内容
+3. 重启程序做验证
+
+# 验证结果
+![结果](test-keyword.png)
\ No newline at end of file
diff --git a/plugins/keyword/__init__.py b/plugins/keyword/__init__.py
new file mode 100644
index 0000000..b860b69
--- /dev/null
+++ b/plugins/keyword/__init__.py
@@ -0,0 +1 @@
+from .keyword import *
diff --git a/plugins/keyword/config.json.template b/plugins/keyword/config.json.template
new file mode 100644
index 0000000..dbd5efe
--- /dev/null
+++ b/plugins/keyword/config.json.template
@@ -0,0 +1,5 @@
+{
+ "keyword": {
+ "关键字匹配": "测试成功"
+ }
+}
diff --git a/plugins/keyword/keyword.py b/plugins/keyword/keyword.py
new file mode 100644
index 0000000..87cd054
--- /dev/null
+++ b/plugins/keyword/keyword.py
@@ -0,0 +1,96 @@
+# encoding:utf-8
+
+import json
+import os
+import requests
+import plugins
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from plugins import *
+
+
+@plugins.register(
+ name="Keyword",
+ desire_priority=900,
+ hidden=True,
+ desc="关键词匹配过滤",
+ version="0.1",
+ author="fengyege.top",
+)
+class Keyword(Plugin):
+ def __init__(self):
+ super().__init__()
+ try:
+ curdir = os.path.dirname(__file__)
+ config_path = os.path.join(curdir, "config.json")
+ conf = None
+ if not os.path.exists(config_path):
+ logger.debug(f"[keyword]不存在配置文件{config_path}")
+ conf = {"keyword": {}}
+ with open(config_path, "w", encoding="utf-8") as f:
+ json.dump(conf, f, indent=4)
+ else:
+ logger.debug(f"[keyword]加载配置文件{config_path}")
+ with open(config_path, "r", encoding="utf-8") as f:
+ conf = json.load(f)
+ # 加载关键词
+ self.keyword = conf["keyword"]
+
+ logger.info("[keyword] {}".format(self.keyword))
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ logger.info("[keyword] inited.")
+ except Exception as e:
+ logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .")
+ raise e
+
+ def on_handle_context(self, e_context: EventContext):
+ if e_context["context"].type != ContextType.TEXT:
+ return
+
+ content = e_context["context"].content.strip()
+ logger.debug("[keyword] on_handle_context. content: %s" % content)
+ if content in self.keyword:
+ logger.info(f"[keyword] 匹配到关键字【{content}】")
+ reply_text = self.keyword[content]
+
+ # 判断匹配内容的类型
+ if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".img"]):
+ # 如果是以 http:// 或 https:// 开头,且".jpg", ".jpeg", ".png", ".gif", ".img"结尾,则认为是图片 URL。
+ reply = Reply()
+ reply.type = ReplyType.IMAGE_URL
+ reply.content = reply_text
+
+ elif (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".pdf", ".doc", ".docx", ".xls", "xlsx",".zip", ".rar"]):
+ # 如果是以 http:// 或 https:// 开头,且".pdf", ".doc", ".docx", ".xls", "xlsx",".zip", ".rar"结尾,则下载文件到tmp目录并发送给用户
+ file_path = "tmp"
+ if not os.path.exists(file_path):
+ os.makedirs(file_path)
+ file_name = reply_text.split("/")[-1] # 获取文件名
+ file_path = os.path.join(file_path, file_name)
+ response = requests.get(reply_text)
+ with open(file_path, "wb") as f:
+ f.write(response.content)
+ #channel/wechat/wechat_channel.py和channel/wechat_channel.py中缺少ReplyType.FILE类型。
+ reply = Reply()
+ reply.type = ReplyType.FILE
+ reply.content = file_path
+
+ elif (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".mp4"]):
+ # 如果是以 http:// 或 https:// 开头,且".mp4"结尾,则下载视频到tmp目录并发送给用户
+ reply = Reply()
+ reply.type = ReplyType.VIDEO_URL
+ reply.content = reply_text
+
+ else:
+ # 否则认为是普通文本
+ reply = Reply()
+ reply.type = ReplyType.TEXT
+ reply.content = reply_text
+
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+
+ def get_help_text(self, **kwargs):
+ help_text = "关键词过滤"
+ return help_text
diff --git a/plugins/keyword/test-keyword.png b/plugins/keyword/test-keyword.png
new file mode 100644
index 0000000..0f17ae8
Binary files /dev/null and b/plugins/keyword/test-keyword.png differ
diff --git a/plugins/linkai/README.md b/plugins/linkai/README.md
new file mode 100644
index 0000000..2ac80b1
--- /dev/null
+++ b/plugins/linkai/README.md
@@ -0,0 +1,109 @@
+## 插件说明
+
+基于 LinkAI 提供的知识库、Midjourney绘画、文档对话等能力对机器人的功能进行增强。平台地址: https://link-ai.tech/console
+
+## 插件配置
+
+将 `plugins/linkai` 目录下的 `config.json.template` 配置模板复制为最终生效的 `config.json`。 (如果未配置则会默认使用`config.json.template`模板中配置,但功能默认关闭,需要可通过指令进行开启)。
+
+以下是插件配置项说明:
+
+```bash
+{
+ "group_app_map": { # 群聊 和 应用编码 的映射关系
+ "测试群名称1": "default", # 表示在名称为 "测试群名称1" 的群聊中将使用app_code 为 default 的应用
+ "测试群名称2": "Kv2fXJcH"
+ },
+ "midjourney": {
+ "enabled": true, # midjourney 绘画开关
+ "auto_translate": true, # 是否自动将提示词翻译为英文
+ "img_proxy": true, # 是否对生成的图片使用代理,如果你是国外服务器,将这一项设置为false会获得更快的生成速度
+ "max_tasks": 3, # 支持同时提交的总任务个数
+ "max_tasks_per_user": 1, # 支持单个用户同时提交的任务个数
+ "use_image_create_prefix": true # 是否使用全局的绘画触发词,如果开启将同时支持由`config.json`中的 image_create_prefix 配置触发
+ },
+ "summary": {
+ "enabled": true, # 文档总结和对话功能开关
+ "group_enabled": true, # 是否支持群聊开启
+ "max_file_size": 5000, # 文件的大小限制,单位KB,默认为5M,超过该大小直接忽略
+ "type": ["FILE", "SHARING", "IMAGE"] # 支持总结的类型,分别表示 文件、分享链接、图片,其中文件和链接默认打开,图片默认关闭
+ }
+}
+```
+
+根目录 `config.json` 中配置,`API_KEY` 在 [控制台](https://link-ai.tech/console/interface) 中创建并复制过来:
+
+```bash
+"linkai_api_key": "Link_xxxxxxxxx"
+```
+
+注意:
+
+ - 配置项中 `group_app_map` 部分是用于映射群聊与LinkAI平台上的应用, `midjourney` 部分是 mj 画图的配置,`summary` 部分是文档总结及对话功能的配置。三部分的配置相互独立,可按需开启
+ - 实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释
+ - 如果是`docker`部署,可通过映射 `plugins/config.json` 到容器中来完成插件配置,参考[文档](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8)
+
+## 插件使用
+
+> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;而midjourney绘画 和 summary文档总结对话功能则只需填写 `linkai_api_key` 配置,`use_linkai` 无论是否关闭均可使用。具体可参考 [详细文档](https://link-ai.tech/platform/link-app/wechat)。
+
+完成配置后运行项目,会自动运行插件,输入 `#help linkai` 可查看插件功能。
+
+### 1.知识库管理功能
+
+提供在不同群聊使用不同应用的功能。可以在上述 `group_app_map` 配置中固定映射关系,也可以通过指令在群中快速完成切换。
+
+应用切换指令需要首先完成管理员 (`godcmd`) 插件的认证,然后按以下格式输入:
+
+`$linkai app {app_code}`
+
+例如输入 `$linkai app Kv2fXJcH`,即将当前群聊与 app_code为 Kv2fXJcH 的应用绑定。
+
+另外,还可以通过 `$linkai close` 来一键关闭linkai对话,此时就会使用默认的openai接口;同理,发送 `$linkai open` 可以再次开启。
+
+### 2.Midjourney绘画功能
+
+若未配置 `plugins/linkai/config.json`,默认会关闭画图功能,直接使用 `$mj open` 可基于默认配置直接使用mj画图。
+
+指令格式:
+
+```
+ - 图片生成: $mj 描述词1, 描述词2..
+ - 图片放大: $mju 图片ID 图片序号
+ - 图片变换: $mjv 图片ID 图片序号
+ - 重置: $mjr 图片ID
+```
+
+例如:
+
+```
+"$mj a little cat, white --ar 9:16"
+"$mju 1105592717188272288 2"
+"$mjv 11055927171882 2"
+"$mjr 11055927171882"
+```
+
+注意事项:
+1. 使用 `$mj open` 和 `$mj close` 指令可以快速打开和关闭绘图功能
+2. 海外环境部署请将 `img_proxy` 设置为 `false`
+3. 开启 `use_image_create_prefix` 配置后可直接复用全局画图触发词,以"画"开头便可以生成图片。
+4. 提示词内容中包含敏感词或者参数格式错误可能导致绘画失败,生成失败不消耗积分
+5. 若未收到图片可能有两种可能,一种是收到了图片但微信发送失败,可以在后台日志查看有没有获取到图片url,一般原因是受到了wx限制,可以稍后重试或更换账号尝试;另一种情况是图片提示词存在疑似违规,mj不会直接提示错误但会在画图后删掉原图导致程序无法获取,这种情况不消耗积分。
+
+### 3.文档总结对话功能
+
+#### 配置
+
+该功能依赖 LinkAI的知识库及对话功能,需要在项目根目录的config.json中设置 `linkai_api_key`, 同时根据上述插件配置说明,在插件config.json添加 `summary` 部分的配置,设置 `enabled` 为 true。
+
+如果不想创建 `plugins/linkai/config.json` 配置,可以直接通过 `$linkai sum open` 指令开启该功能。
+
+#### 使用
+
+功能开启后,向机器人发送 **文件**、 **分享链接卡片**、**图片** 即可生成摘要,进一步可以与文件或链接的内容进行多轮对话。如果需要关闭某种类型的内容总结,设置 `summary`配置中的type字段即可。
+
+#### 限制
+
+ 1. 文件目前 支持 `txt`, `docx`, `pdf`, `md`, `csv`格式,文件大小由 `max_file_size` 限制,最大不超过15M,文件字数最多可支持百万字的文件。但不建议上传字数过多的文件,一是token消耗过大,二是摘要很难覆盖到全部内容,只能通过多轮对话来了解细节。
+ 2. 分享链接 目前仅支持 公众号文章,后续会支持更多文章类型及视频链接等
+ 3. 总结及对话的 费用与 LinkAI 3.5-4K 模型的计费方式相同,按文档内容的tokens进行计算
diff --git a/plugins/linkai/__init__.py b/plugins/linkai/__init__.py
new file mode 100644
index 0000000..e7414be
--- /dev/null
+++ b/plugins/linkai/__init__.py
@@ -0,0 +1 @@
+from .linkai import *
diff --git a/plugins/linkai/config.json.template b/plugins/linkai/config.json.template
new file mode 100644
index 0000000..547b8ef
--- /dev/null
+++ b/plugins/linkai/config.json.template
@@ -0,0 +1,20 @@
+{
+ "group_app_map": {
+ "测试群名1": "default",
+ "测试群名2": "Kv2fXJcH"
+ },
+ "midjourney": {
+ "enabled": true,
+ "auto_translate": true,
+ "img_proxy": true,
+ "max_tasks": 3,
+ "max_tasks_per_user": 1,
+ "use_image_create_prefix": true
+ },
+ "summary": {
+ "enabled": true,
+ "group_enabled": true,
+ "max_file_size": 5000,
+ "type": ["FILE", "SHARING"]
+ }
+}
diff --git a/plugins/linkai/linkai.py b/plugins/linkai/linkai.py
new file mode 100644
index 0000000..7978743
--- /dev/null
+++ b/plugins/linkai/linkai.py
@@ -0,0 +1,287 @@
+import plugins
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from plugins import *
+from .midjourney import MJBot
+from .summary import LinkSummary
+from bridge import bridge
+from common.expired_dict import ExpiredDict
+from common import const
+import os
+from .utils import Util
+
+@plugins.register(
+ name="linkai",
+ desc="A plugin that supports knowledge base and midjourney drawing.",
+ version="0.1.0",
+ author="https://link-ai.tech",
+ desire_priority=99
+)
+class LinkAI(Plugin):
+ def __init__(self):
+ super().__init__()
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ self.config = super().load_config()
+ if not self.config:
+ # 未加载到配置,使用模板中的配置
+ self.config = self._load_config_template()
+ if self.config:
+ self.mj_bot = MJBot(self.config.get("midjourney"))
+ self.sum_config = {}
+ if self.config:
+ self.sum_config = self.config.get("summary")
+ logger.info(f"[LinkAI] inited, config={self.config}")
+
+
+ def on_handle_context(self, e_context: EventContext):
+ """
+ 消息处理逻辑
+ :param e_context: 消息上下文
+ """
+ if not self.config:
+ return
+
+ context = e_context['context']
+ if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE, ContextType.FILE, ContextType.SHARING]:
+ # filter content no need solve
+ return
+
+ if context.type in [ContextType.FILE, ContextType.IMAGE] and self._is_summary_open(context):
+ # 文件处理
+ context.get("msg").prepare()
+ file_path = context.content
+ if not LinkSummary().check_file(file_path, self.sum_config):
+ return
+ if context.type != ContextType.IMAGE:
+ _send_info(e_context, "正在为你加速生成摘要,请稍后")
+ res = LinkSummary().summary_file(file_path)
+ if not res:
+ if context.type != ContextType.IMAGE:
+ _set_reply_text("因为神秘力量无法获取内容,请稍后再试吧", e_context, level=ReplyType.TEXT)
+ return
+ summary_text = res.get("summary")
+ if context.type != ContextType.IMAGE:
+ USER_FILE_MAP[_find_user_id(context) + "-sum_id"] = res.get("summary_id")
+ summary_text += "\n\n💬 发送 \"开启对话\" 可以开启与文件内容的对话"
+ _set_reply_text(summary_text, e_context, level=ReplyType.TEXT)
+ os.remove(file_path)
+ return
+
+ if (context.type == ContextType.SHARING and self._is_summary_open(context)) or \
+ (context.type == ContextType.TEXT and LinkSummary().check_url(context.content)):
+ if not LinkSummary().check_url(context.content):
+ return
+ _send_info(e_context, "正在为你加速生成摘要,请稍后")
+ res = LinkSummary().summary_url(context.content)
+ if not res:
+ _set_reply_text("因为神秘力量无法获取文章内容,请稍后再试吧~", e_context, level=ReplyType.TEXT)
+ return
+ _set_reply_text(res.get("summary") + "\n\n💬 发送 \"开启对话\" 可以开启与文章内容的对话", e_context, level=ReplyType.TEXT)
+ USER_FILE_MAP[_find_user_id(context) + "-sum_id"] = res.get("summary_id")
+ return
+
+ mj_type = self.mj_bot.judge_mj_task_type(e_context)
+ if mj_type:
+ # MJ作图任务处理
+ self.mj_bot.process_mj_task(mj_type, e_context)
+ return
+
+ if context.content.startswith(f"{_get_trigger_prefix()}linkai"):
+ # 应用管理功能
+ self._process_admin_cmd(e_context)
+ return
+
+ if context.type == ContextType.TEXT and context.content == "开启对话" and _find_sum_id(context):
+ # 文本对话
+ _send_info(e_context, "正在为你开启对话,请稍后")
+ res = LinkSummary().summary_chat(_find_sum_id(context))
+ if not res:
+ _set_reply_text("开启对话失败,请稍后再试吧", e_context)
+ return
+ USER_FILE_MAP[_find_user_id(context) + "-file_id"] = res.get("file_id")
+ _set_reply_text("💡你可以问我关于这篇文章的任何问题,例如:\n\n" + res.get("questions") + "\n\n发送 \"退出对话\" 可以关闭与文章的对话", e_context, level=ReplyType.TEXT)
+ return
+
+ if context.type == ContextType.TEXT and context.content == "退出对话" and _find_file_id(context):
+ del USER_FILE_MAP[_find_user_id(context) + "-file_id"]
+ bot = bridge.Bridge().find_chat_bot(const.LINKAI)
+ bot.sessions.clear_session(context["session_id"])
+ _set_reply_text("对话已退出", e_context, level=ReplyType.TEXT)
+ return
+
+ if context.type == ContextType.TEXT and _find_file_id(context):
+ bot = bridge.Bridge().find_chat_bot(const.LINKAI)
+ context.kwargs["file_id"] = _find_file_id(context)
+ reply = bot.reply(context.content, context)
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+
+
+ if self._is_chat_task(e_context):
+ # 文本对话任务处理
+ self._process_chat_task(e_context)
+
+
+ # 插件管理功能
+ def _process_admin_cmd(self, e_context: EventContext):
+ context = e_context['context']
+ cmd = context.content.split()
+ if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"):
+ _set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
+ return
+
+ if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
+ # 知识库开关指令
+ if not Util.is_admin(e_context):
+ _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
+ return
+ is_open = True
+ tips_text = "开启"
+ if cmd[1] == "close":
+ tips_text = "关闭"
+ is_open = False
+ conf()["use_linkai"] = is_open
+ bridge.Bridge().reset_bot()
+ _set_reply_text(f"LinkAI对话功能{tips_text}", e_context, level=ReplyType.INFO)
+ return
+
+ if len(cmd) == 3 and cmd[1] == "app":
+ # 知识库应用切换指令
+ if not context.kwargs.get("isgroup"):
+ _set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR)
+ return
+ if not Util.is_admin(e_context):
+ _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
+ return
+ app_code = cmd[2]
+ group_name = context.kwargs.get("msg").from_user_nickname
+ group_mapping = self.config.get("group_app_map")
+ if group_mapping:
+ group_mapping[group_name] = app_code
+ else:
+ self.config["group_app_map"] = {group_name: app_code}
+ # 保存插件配置
+ super().save_config(self.config)
+ _set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO)
+ return
+
+ if len(cmd) == 3 and cmd[1] == "sum" and (cmd[2] == "open" or cmd[2] == "close"):
+ # 知识库开关指令
+ if not Util.is_admin(e_context):
+ _set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
+ return
+ is_open = True
+ tips_text = "开启"
+ if cmd[2] == "close":
+ tips_text = "关闭"
+ is_open = False
+ if not self.sum_config:
+ _set_reply_text(f"插件未启用summary功能,请参考以下链添加插件配置\n\nhttps://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/linkai/README.md", e_context, level=ReplyType.INFO)
+ else:
+ self.sum_config["enabled"] = is_open
+ _set_reply_text(f"文章总结功能{tips_text}", e_context, level=ReplyType.INFO)
+ return
+
+ _set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context,
+ level=ReplyType.INFO)
+ return
+
+ def _is_summary_open(self, context) -> bool:
+ if not self.sum_config or not self.sum_config.get("enabled"):
+ return False
+ if context.kwargs.get("isgroup") and not self.sum_config.get("group_enabled"):
+ return False
+ support_type = self.sum_config.get("type") or ["FILE", "SHARING"]
+ if context.type.name not in support_type:
+ return False
+ return True
+
+ # LinkAI 对话任务处理
+ def _is_chat_task(self, e_context: EventContext):
+ context = e_context['context']
+ # 群聊应用管理
+ return self.config.get("group_app_map") and context.kwargs.get("isgroup")
+
+ def _process_chat_task(self, e_context: EventContext):
+ """
+ 处理LinkAI对话任务
+ :param e_context: 对话上下文
+ """
+ context = e_context['context']
+ # 群聊应用管理
+ group_name = context.get("msg").from_user_nickname
+ app_code = self._fetch_group_app_code(group_name)
+ if app_code:
+ context.kwargs['app_code'] = app_code
+
+ def _fetch_group_app_code(self, group_name: str) -> str:
+ """
+ 根据群聊名称获取对应的应用code
+ :param group_name: 群聊名称
+ :return: 应用code
+ """
+ group_mapping = self.config.get("group_app_map")
+ if group_mapping:
+ app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
+ return app_code
+
+ def get_help_text(self, verbose=False, **kwargs):
+ trigger_prefix = _get_trigger_prefix()
+ help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画、文档总结、联网搜索等能力。\n\n"
+ if not verbose:
+ return help_text
+ help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n'
+ help_text += f' - {trigger_prefix}linkai open: 开启对话\n'
+ help_text += f' - {trigger_prefix}linkai close: 关闭对话\n'
+ help_text += f'\n例如: \n"{trigger_prefix}linkai app Kv2fXJcH"\n\n'
+ help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: {trigger_prefix}mjv 图片ID 图片序号\n - 重置: {trigger_prefix}mjr 图片ID"
+ help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
+ help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
+ help_text += f"\n\n💡 文档总结和对话\n - 开启: {trigger_prefix}linkai sum open\n - 使用: 发送文件、公众号文章等可生成摘要,并与内容对话"
+ return help_text
+
+ def _load_config_template(self):
+ logger.debug("No LinkAI plugin config.json, use plugins/linkai/config.json.template")
+ try:
+ plugin_config_path = os.path.join(self.path, "config.json.template")
+ if os.path.exists(plugin_config_path):
+ with open(plugin_config_path, "r", encoding="utf-8") as f:
+ plugin_conf = json.load(f)
+ plugin_conf["midjourney"]["enabled"] = False
+ plugin_conf["summary"]["enabled"] = False
+ return plugin_conf
+ except Exception as e:
+ logger.exception(e)
+
+
+def _send_info(e_context: EventContext, content: str):
+ reply = Reply(ReplyType.TEXT, content)
+ channel = e_context["channel"]
+ channel.send(reply, e_context["context"])
+
+
+def _find_user_id(context):
+ if context["isgroup"]:
+ return context.kwargs.get("msg").actual_user_id
+ else:
+ return context["receiver"]
+
+
+def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
+ reply = Reply(level, content)
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+
+def _get_trigger_prefix():
+ return conf().get("plugin_trigger_prefix", "$")
+
+def _find_sum_id(context):
+ return USER_FILE_MAP.get(_find_user_id(context) + "-sum_id")
+
+def _find_file_id(context):
+ user_id = _find_user_id(context)
+ if user_id:
+ return USER_FILE_MAP.get(user_id + "-file_id")
+
+USER_FILE_MAP = ExpiredDict(conf().get("expires_in_seconds") or 60 * 30)
diff --git a/plugins/linkai/midjourney.py b/plugins/linkai/midjourney.py
new file mode 100644
index 0000000..9c6c57b
--- /dev/null
+++ b/plugins/linkai/midjourney.py
@@ -0,0 +1,432 @@
+from enum import Enum
+from config import conf
+from common.log import logger
+import requests
+import threading
+import time
+from bridge.reply import Reply, ReplyType
+import asyncio
+from bridge.context import ContextType
+from plugins import EventContext, EventAction
+from .utils import Util
+
+INVALID_REQUEST = 410
+NOT_FOUND_ORIGIN_IMAGE = 461
+NOT_FOUND_TASK = 462
+
+
+class TaskType(Enum):
+ GENERATE = "generate"
+ UPSCALE = "upscale"
+ VARIATION = "variation"
+ RESET = "reset"
+
+ def __str__(self):
+ return self.name
+
+
+class Status(Enum):
+ PENDING = "pending"
+ FINISHED = "finished"
+ EXPIRED = "expired"
+ ABORTED = "aborted"
+
+ def __str__(self):
+ return self.name
+
+
+class TaskMode(Enum):
+ FAST = "fast"
+ RELAX = "relax"
+
+
+task_name_mapping = {
+ TaskType.GENERATE.name: "生成",
+ TaskType.UPSCALE.name: "放大",
+ TaskType.VARIATION.name: "变换",
+ TaskType.RESET.name: "重新生成",
+}
+
+
+class MJTask:
+ def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 6,
+ status=Status.PENDING):
+ self.id = id
+ self.user_id = user_id
+ self.task_type = task_type
+ self.raw_prompt = raw_prompt
+ self.send_func = None # send_func(img_url)
+ self.expiry_time = time.time() + expires
+ self.status = status
+ self.img_url = None # url
+ self.img_id = None
+
+ def __str__(self):
+ return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
+
+
+# midjourney bot
+class MJBot:
+ def __init__(self, config):
+ self.base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/img/midjourney"
+ self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
+ self.config = config
+ self.tasks = {}
+ self.temp_dict = {}
+ self.tasks_lock = threading.Lock()
+ self.event_loop = asyncio.new_event_loop()
+
+ def judge_mj_task_type(self, e_context: EventContext):
+ """
+ 判断MJ任务的类型
+ :param e_context: 上下文
+ :return: 任务类型枚举
+ """
+ if not self.config:
+ return None
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ context = e_context['context']
+ if context.type == ContextType.TEXT:
+ cmd_list = context.content.split(maxsplit=1)
+ if not cmd_list:
+ return None
+ if cmd_list[0].lower() == f"{trigger_prefix}mj":
+ return TaskType.GENERATE
+ elif cmd_list[0].lower() == f"{trigger_prefix}mju":
+ return TaskType.UPSCALE
+ elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
+ return TaskType.VARIATION
+ elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
+ return TaskType.RESET
+ elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix") and self.config.get("enabled"):
+ return TaskType.GENERATE
+
+ def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
+ """
+ 处理mj任务
+ :param mj_type: mj任务类型
+ :param e_context: 对话上下文
+ """
+ context = e_context['context']
+ session_id = context["session_id"]
+ cmd = context.content.split(maxsplit=1)
+ if len(cmd) == 1 and context.type == ContextType.TEXT:
+ # midjourney 帮助指令
+ self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
+ return
+
+ if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
+ if not Util.is_admin(e_context):
+ Util.set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
+ return
+ # midjourney 开关指令
+ is_open = True
+ tips_text = "开启"
+ if cmd[1] == "close":
+ tips_text = "关闭"
+ is_open = False
+ self.config["enabled"] = is_open
+ self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
+ return
+
+ if not self.config.get("enabled"):
+ logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
+ self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
+ return
+
+ if not self._check_rate_limit(session_id, e_context):
+ logger.warn("[MJ] midjourney task exceed rate limit")
+ return
+
+ if mj_type == TaskType.GENERATE:
+ if context.type == ContextType.IMAGE_CREATE:
+ raw_prompt = context.content
+ else:
+ # 图片生成
+ raw_prompt = cmd[1]
+ reply = self.generate(raw_prompt, session_id, e_context)
+ e_context['reply'] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+
+ elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION:
+ # 图片放大/变换
+ clist = cmd[1].split()
+ if len(clist) < 2:
+ self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
+ return
+ img_id = clist[0]
+ index = int(clist[1])
+ if index < 1 or index > 4:
+ self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
+ return
+ key = f"{str(mj_type)}_{img_id}_{index}"
+ if self.temp_dict.get(key):
+ self._set_reply_text(f"第 {index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context)
+ return
+ # 执行图片放大/变换操作
+ reply = self.do_operate(mj_type, session_id, img_id, e_context, index)
+ e_context['reply'] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+
+ elif mj_type == TaskType.RESET:
+ # 图片重新生成
+ clist = cmd[1].split()
+ if len(clist) < 1:
+ self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
+ return
+ img_id = clist[0]
+ # 图片重新生成
+ reply = self.do_operate(mj_type, session_id, img_id, e_context)
+ e_context['reply'] = reply
+ e_context.action = EventAction.BREAK_PASS
+ else:
+ self._set_reply_text(f"暂不支持该命令", e_context)
+
+ def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
+ """
+ 图片生成
+ :param prompt: 提示词
+ :param user_id: 用户id
+ :param e_context: 对话上下文
+ :return: 任务ID
+ """
+ logger.info(f"[MJ] image generate, prompt={prompt}")
+ mode = self._fetch_mode(prompt)
+ body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
+ if not self.config.get("img_proxy"):
+ body["img_proxy"] = False
+ res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
+ if res.status_code == 200:
+ res = res.json()
+ logger.debug(f"[MJ] image generate, res={res}")
+ if res.get("code") == 200:
+ task_id = res.get("data").get("task_id")
+ real_prompt = res.get("data").get("real_prompt")
+ if mode == TaskMode.RELAX.value:
+ time_str = "1~10分钟"
+ else:
+ time_str = "1分钟"
+ content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
+ if real_prompt:
+ content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
+ else:
+ content += f"prompt: {prompt}"
+ reply = Reply(ReplyType.INFO, content)
+ task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
+ task_type=TaskType.GENERATE)
+ # put to memory dict
+ self.tasks[task.id] = task
+ # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
+ self._do_check_task(task, e_context)
+ return reply
+ else:
+ res_json = res.json()
+ logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
+ if res.status_code == INVALID_REQUEST:
+ reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
+ else:
+ reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
+ return reply
+
+ def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext,
+ index: int = None) -> Reply:
+ logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}")
+ body = {"type": task_type.name, "img_id": img_id}
+ if index:
+ body["index"] = index
+ if not self.config.get("img_proxy"):
+ body["img_proxy"] = False
+ res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
+ logger.debug(res)
+ if res.status_code == 200:
+ res = res.json()
+ if res.get("code") == 200:
+ task_id = res.get("data").get("task_id")
+ logger.info(f"[MJ] image operate processing, task_id={task_id}")
+ icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"}
+ content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待"
+ reply = Reply(ReplyType.INFO, content)
+ task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type)
+ # put to memory dict
+ self.tasks[task.id] = task
+ key = f"{task_type.name}_{img_id}_{index}"
+ self.temp_dict[key] = True
+ # asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
+ self._do_check_task(task, e_context)
+ return reply
+ else:
+ error_msg = ""
+ if res.status_code == NOT_FOUND_ORIGIN_IMAGE:
+ error_msg = "请输入正确的图片ID"
+ res_json = res.json()
+ logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}")
+ reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
+ return reply
+
+ def check_task_sync(self, task: MJTask, e_context: EventContext):
+ logger.debug(f"[MJ] start check task status, {task}")
+ max_retry_times = 90
+ while max_retry_times > 0:
+ time.sleep(10)
+ url = f"{self.base_url}/tasks/{task.id}"
+ try:
+ res = requests.get(url, headers=self.headers, timeout=8)
+ if res.status_code == 200:
+ res_json = res.json()
+ logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
+ f"data={res_json.get('data')}, thread={threading.current_thread().name}")
+ if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
+ # process success res
+ if self.tasks.get(task.id):
+ self.tasks[task.id].status = Status.FINISHED
+ self._process_success_task(task, res_json.get("data"), e_context)
+ return
+ max_retry_times -= 1
+ else:
+ res_json = res.json()
+ logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
+ max_retry_times -= 20
+ except Exception as e:
+ max_retry_times -= 20
+ logger.warn(e)
+ logger.warn("[MJ] end from poll")
+ if self.tasks.get(task.id):
+ self.tasks[task.id].status = Status.EXPIRED
+
+ def _do_check_task(self, task: MJTask, e_context: EventContext):
+ threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
+
+ def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
+ """
+ 处理任务成功的结果
+ :param task: MJ任务
+ :param res: 请求结果
+ :param e_context: 对话上下文
+ """
+ # channel send img
+ task.status = Status.FINISHED
+ task.img_id = res.get("img_id")
+ task.img_url = res.get("img_url")
+ logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
+
+ # send img
+ reply = Reply(ReplyType.IMAGE_URL, task.img_url)
+ channel = e_context["channel"]
+ _send(channel, reply, e_context["context"])
+
+ # send info
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ text = ""
+ if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET:
+ text = f"🎨绘画完成!\n"
+ if task.raw_prompt:
+ text += f"prompt: {task.raw_prompt}\n"
+ text += f"- - - - - - - - -\n图片ID: {task.img_id}"
+ text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n"
+ text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
+ text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n"
+ text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1"
+ text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n"
+ text += f"例如:\n{trigger_prefix}mjr {task.img_id}"
+ reply = Reply(ReplyType.INFO, text)
+ _send(channel, reply, e_context["context"])
+
+ self._print_tasks()
+ return
+
+ def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
+ """
+ midjourney任务限流控制
+ :param user_id: 用户id
+ :param e_context: 对话上下文
+ :return: 任务是否能够生成, True:可以生成, False: 被限流
+ """
+ tasks = self.find_tasks_by_user_id(user_id)
+ task_count = len([t for t in tasks if t.status == Status.PENDING])
+ if task_count >= self.config.get("max_tasks_per_user"):
+ reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return False
+ task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
+ if task_count >= self.config.get("max_tasks"):
+ reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试")
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return False
+ return True
+
+ def _fetch_mode(self, prompt) -> str:
+ mode = self.config.get("mode")
+ if "--relax" in prompt or mode == TaskMode.RELAX.value:
+ return TaskMode.RELAX.value
+ return mode or TaskMode.FAST.value
+
+ def _run_loop(self, loop: asyncio.BaseEventLoop):
+ """
+ 运行事件循环,用于轮询任务的线程
+ :param loop: 事件循环
+ """
+ loop.run_forever()
+ loop.stop()
+
+ def _print_tasks(self):
+ for id in self.tasks:
+ logger.debug(f"[MJ] current task: {self.tasks[id]}")
+
+ def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
+ """
+ 设置回复文本
+ :param content: 回复内容
+ :param e_context: 对话上下文
+ :param level: 回复等级
+ """
+ reply = Reply(level, content)
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+
+ def get_help_text(self, verbose=False, **kwargs):
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ help_text = "🎨利用Midjourney进行画图\n\n"
+ if not verbose:
+ return help_text
+ help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID"
+ help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
+ help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
+ return help_text
+
+ def find_tasks_by_user_id(self, user_id) -> list:
+ result = []
+ with self.tasks_lock:
+ now = time.time()
+ for task in self.tasks.values():
+ if task.status == Status.PENDING and now > task.expiry_time:
+ task.status = Status.EXPIRED
+ logger.info(f"[MJ] {task} expired")
+ if task.user_id == user_id:
+ result.append(task)
+ return result
+
+
+def _send(channel, reply: Reply, context, retry_cnt=0):
+ try:
+ channel.send(reply, context)
+ except Exception as e:
+ logger.error("[WX] sendMsg error: {}".format(str(e)))
+ if isinstance(e, NotImplementedError):
+ return
+ logger.exception(e)
+ if retry_cnt < 2:
+ time.sleep(3 + 3 * retry_cnt)
+ channel.send(reply, context, retry_cnt + 1)
+
+
+def check_prefix(content, prefix_list):
+ if not prefix_list:
+ return None
+ for prefix in prefix_list:
+ if content.startswith(prefix):
+ return prefix
+ return None
diff --git a/plugins/linkai/summary.py b/plugins/linkai/summary.py
new file mode 100644
index 0000000..5711fd9
--- /dev/null
+++ b/plugins/linkai/summary.py
@@ -0,0 +1,94 @@
+import requests
+from config import conf
+from common.log import logger
+import os
+
+
+class LinkSummary:
+ def __init__(self):
+ pass
+
+ def summary_file(self, file_path: str):
+ file_body = {
+ "file": open(file_path, "rb"),
+ "name": file_path.split("/")[-1],
+ }
+ url = self.base_url() + "/v1/summary/file"
+ res = requests.post(url, headers=self.headers(), files=file_body, timeout=(5, 300))
+ return self._parse_summary_res(res)
+
+ def summary_url(self, url: str):
+ body = {
+ "url": url
+ }
+ res = requests.post(url=self.base_url() + "/v1/summary/url", headers=self.headers(), json=body, timeout=(5, 180))
+ return self._parse_summary_res(res)
+
+ def summary_chat(self, summary_id: str):
+ body = {
+ "summary_id": summary_id
+ }
+ res = requests.post(url=self.base_url() + "/v1/summary/chat", headers=self.headers(), json=body, timeout=(5, 180))
+ if res.status_code == 200:
+ res = res.json()
+ logger.debug(f"[LinkSum] chat open, res={res}")
+ if res.get("code") == 200:
+ data = res.get("data")
+ return {
+ "questions": data.get("questions"),
+ "file_id": data.get("file_id")
+ }
+ else:
+ res_json = res.json()
+ logger.error(f"[LinkSum] summary error, status_code={res.status_code}, msg={res_json.get('message')}")
+ return None
+
+ def _parse_summary_res(self, res):
+ if res.status_code == 200:
+ res = res.json()
+ logger.debug(f"[LinkSum] url summary, res={res}")
+ if res.get("code") == 200:
+ data = res.get("data")
+ return {
+ "summary": data.get("summary"),
+ "summary_id": data.get("summary_id")
+ }
+ else:
+ res_json = res.json()
+ logger.error(f"[LinkSum] summary error, status_code={res.status_code}, msg={res_json.get('message')}")
+ return None
+
+ def base_url(self):
+ return conf().get("linkai_api_base", "https://api.link-ai.chat")
+
+ def headers(self):
+ return {"Authorization": "Bearer " + conf().get("linkai_api_key")}
+
+ def check_file(self, file_path: str, sum_config: dict) -> bool:
+ file_size = os.path.getsize(file_path) // 1000
+
+ if (sum_config.get("max_file_size") and file_size > sum_config.get("max_file_size")) or file_size > 15000:
+ logger.warn(f"[LinkSum] file size exceeds limit, No processing, file_size={file_size}KB")
+ return False
+
+ suffix = file_path.split(".")[-1]
+ support_list = ["txt", "csv", "docx", "pdf", "md", "jpg", "jpeg", "png"]
+ if suffix not in support_list:
+ logger.warn(f"[LinkSum] unsupported file, suffix={suffix}, support_list={support_list}")
+ return False
+
+ return True
+
+ def check_url(self, url: str):
+ if not url:
+ return False
+ support_list = ["http://mp.weixin.qq.com", "https://mp.weixin.qq.com"]
+ black_support_list = ["https://mp.weixin.qq.com/mp/waerrpage"]
+ for black_url_prefix in black_support_list:
+ if url.strip().startswith(black_url_prefix):
+ logger.warn(f"[LinkSum] unsupported url, no need to process, url={url}")
+ return False
+ for support_url in support_list:
+ if url.strip().startswith(support_url):
+ return True
+ return False
diff --git a/plugins/linkai/utils.py b/plugins/linkai/utils.py
new file mode 100644
index 0000000..c874cdf
--- /dev/null
+++ b/plugins/linkai/utils.py
@@ -0,0 +1,28 @@
+from config import global_config
+from bridge.reply import Reply, ReplyType
+from plugins.event import EventContext, EventAction
+
+
+class Util:
+ @staticmethod
+ def is_admin(e_context: EventContext) -> bool:
+ """
+ 判断消息是否由管理员用户发送
+ :param e_context: 消息上下文
+ :return: True: 是, False: 否
+ """
+ context = e_context["context"]
+ if context["isgroup"]:
+ actual_user_id = context.kwargs.get("msg").actual_user_id
+ for admin_user in global_config["admin_users"]:
+ if actual_user_id and actual_user_id in admin_user:
+ return True
+ return False
+ else:
+ return context["receiver"] in global_config["admin_users"]
+
+ @staticmethod
+ def set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
+ reply = Reply(level, content)
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
diff --git a/plugins/plugin.py b/plugins/plugin.py
new file mode 100644
index 0000000..f4c9618
--- /dev/null
+++ b/plugins/plugin.py
@@ -0,0 +1,51 @@
+import os
+import json
+from config import pconf, plugin_config, conf
+from common.log import logger
+
+
+class Plugin:
+ def __init__(self):
+ self.handlers = {}
+
+ def load_config(self) -> dict:
+ """
+ 加载当前插件配置
+ :return: 插件配置字典
+ """
+ # 优先获取 plugins/config.json 中的全局配置
+ plugin_conf = pconf(self.name)
+ if not plugin_conf:
+ # 全局配置不存在,则获取插件目录下的配置
+ plugin_config_path = os.path.join(self.path, "config.json")
+ if os.path.exists(plugin_config_path):
+ with open(plugin_config_path, "r", encoding="utf-8") as f:
+ plugin_conf = json.load(f)
+
+ # 写入全局配置内存
+ plugin_config[self.name] = plugin_conf
+ logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
+ return plugin_conf
+
+ def save_config(self, config: dict):
+ try:
+ plugin_config[self.name] = config
+ # 写入全局配置
+ global_config_path = "./plugins/config.json"
+ if os.path.exists(global_config_path):
+ with open(global_config_path, "w", encoding='utf-8') as f:
+ json.dump(plugin_config, f, indent=4, ensure_ascii=False)
+ # 写入插件配置
+ plugin_config_path = os.path.join(self.path, "config.json")
+ if os.path.exists(plugin_config_path):
+ with open(plugin_config_path, "w", encoding='utf-8') as f:
+ json.dump(config, f, indent=4, ensure_ascii=False)
+
+ except Exception as e:
+ logger.warn("save plugin config failed: {}".format(e))
+
+ def get_help_text(self, **kwargs):
+ return "暂无帮助信息"
+
+ def reload(self):
+ pass
diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py
new file mode 100644
index 0000000..49c13ca
--- /dev/null
+++ b/plugins/plugin_manager.py
@@ -0,0 +1,338 @@
+# encoding:utf-8
+
+import importlib
+import importlib.util
+import json
+import os
+import sys
+
+from common.log import logger
+from common.singleton import singleton
+from common.sorted_dict import SortedDict
+from config import conf, write_plugin_config
+
+from .event import *
+
+
+@singleton
+class PluginManager:
+ def __init__(self):
+ self.plugins = SortedDict(lambda k, v: v.priority, reverse=True)
+ self.listening_plugins = {}
+ self.instances = {}
+ self.pconf = {}
+ self.current_plugin_path = None
+ self.loaded = {}
+
+ def register(self, name: str, desire_priority: int = 0, **kwargs):
+ def wrapper(plugincls):
+ plugincls.name = name
+ plugincls.priority = desire_priority
+ plugincls.desc = kwargs.get("desc")
+ plugincls.author = kwargs.get("author")
+ plugincls.path = self.current_plugin_path
+ plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0"
+ plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name
+ plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False
+ plugincls.enabled = True
+ if self.current_plugin_path == None:
+ raise Exception("Plugin path not set")
+ self.plugins[name.upper()] = plugincls
+ logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
+
+ return wrapper
+
+ def save_config(self):
+ with open("./plugins/plugins.json", "w", encoding="utf-8") as f:
+ json.dump(self.pconf, f, indent=4, ensure_ascii=False)
+
+ def load_config(self):
+ logger.info("Loading plugins config...")
+
+ modified = False
+ if os.path.exists("./plugins/plugins.json"):
+ with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
+ pconf = json.load(f)
+ pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True)
+ else:
+ modified = True
+ pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
+ self.pconf = pconf
+ if modified:
+ self.save_config()
+ return pconf
+
+ @staticmethod
+ def _load_all_config():
+ """
+ 背景: 目前插件配置存放于每个插件目录的config.json下,docker运行时不方便进行映射,故增加统一管理的入口,优先
+ 加载 plugins/config.json,原插件目录下的config.json 不受影响
+
+ 从 plugins/config.json 中加载所有插件的配置并写入 config.py 的全局配置中,供插件中使用
+ 插件实例中通过 config.pconf(plugin_name) 即可获取该插件的配置
+ """
+ all_config_path = "./plugins/config.json"
+ try:
+ if os.path.exists(all_config_path):
+ # read from all plugins config
+ with open(all_config_path, "r", encoding="utf-8") as f:
+ all_conf = json.load(f)
+ logger.info(f"load all config from plugins/config.json: {all_conf}")
+
+ # write to global config
+ write_plugin_config(all_conf)
+ except Exception as e:
+ logger.error(e)
+
+ def scan_plugins(self):
+ logger.info("Scaning plugins ...")
+ plugins_dir = "./plugins"
+ raws = [self.plugins[name] for name in self.plugins]
+ for plugin_name in os.listdir(plugins_dir):
+ plugin_path = os.path.join(plugins_dir, plugin_name)
+ if os.path.isdir(plugin_path):
+ # 判断插件是否包含同名__init__.py文件
+ main_module_path = os.path.join(plugin_path, "__init__.py")
+ if os.path.isfile(main_module_path):
+ # 导入插件
+ import_path = "plugins.{}".format(plugin_name)
+ try:
+ self.current_plugin_path = plugin_path
+ if plugin_path in self.loaded:
+ if self.loaded[plugin_path] == None:
+ logger.info("reload module %s" % plugin_name)
+ self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
+ dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
+ for name in dependent_module_names:
+ logger.info("reload module %s" % name)
+ importlib.reload(sys.modules[name])
+ else:
+ self.loaded[plugin_path] = importlib.import_module(import_path)
+ self.current_plugin_path = None
+ except Exception as e:
+ logger.warn("Failed to import plugin %s: %s" % (plugin_name, e))
+ continue
+ pconf = self.pconf
+ news = [self.plugins[name] for name in self.plugins]
+ new_plugins = list(set(news) - set(raws))
+ modified = False
+ for name, plugincls in self.plugins.items():
+ rawname = plugincls.name
+ if rawname not in pconf["plugins"]:
+ modified = True
+ logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
+ pconf["plugins"][rawname] = {
+ "enabled": plugincls.enabled,
+ "priority": plugincls.priority,
+ }
+ else:
+ self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"]
+ self.plugins[name].priority = pconf["plugins"][rawname]["priority"]
+ self.plugins._update_heap(name) # 更新下plugins中的顺序
+ if modified:
+ self.save_config()
+ return new_plugins
+
+ def refresh_order(self):
+ for event in self.listening_plugins.keys():
+ self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
+
+ def activate_plugins(self): # 生成新开启的插件实例
+ failed_plugins = []
+ for name, plugincls in self.plugins.items():
+ if plugincls.enabled:
+ if name not in self.instances:
+ try:
+ instance = plugincls()
+ except Exception as e:
+ logger.warn("Failed to init %s, diabled. %s" % (name, e))
+ self.disable_plugin(name)
+ failed_plugins.append(name)
+ continue
+ self.instances[name] = instance
+ for event in instance.handlers:
+ if event not in self.listening_plugins:
+ self.listening_plugins[event] = []
+ self.listening_plugins[event].append(name)
+ self.refresh_order()
+ return failed_plugins
+
+ def reload_plugin(self, name: str):
+ name = name.upper()
+ if name in self.instances:
+ for event in self.listening_plugins:
+ if name in self.listening_plugins[event]:
+ self.listening_plugins[event].remove(name)
+ del self.instances[name]
+ self.activate_plugins()
+ return True
+ return False
+
+ def load_plugins(self):
+ self.load_config()
+ self.scan_plugins()
+ # 加载全量插件配置
+ self._load_all_config()
+ pconf = self.pconf
+ logger.debug("plugins.json config={}".format(pconf))
+ for name, plugin in pconf["plugins"].items():
+ if name.upper() not in self.plugins:
+ logger.error("Plugin %s not found, but found in plugins.json" % name)
+ self.activate_plugins()
+
+ def emit_event(self, e_context: EventContext, *args, **kwargs):
+ if e_context.event in self.listening_plugins:
+ for name in self.listening_plugins[e_context.event]:
+ if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
+ logger.debug("Plugin %s triggered by event %s" % (name, e_context.event))
+ instance = self.instances[name]
+ instance.handlers[e_context.event](e_context, *args, **kwargs)
+ if e_context.is_break():
+ e_context["breaked_by"] = name
+ logger.debug("Plugin %s breaked event %s" % (name, e_context.event))
+ return e_context
+
+ def set_plugin_priority(self, name: str, priority: int):
+ name = name.upper()
+ if name not in self.plugins:
+ return False
+ if self.plugins[name].priority == priority:
+ return True
+ self.plugins[name].priority = priority
+ self.plugins._update_heap(name)
+ rawname = self.plugins[name].name
+ self.pconf["plugins"][rawname]["priority"] = priority
+ self.pconf["plugins"]._update_heap(rawname)
+ self.save_config()
+ self.refresh_order()
+ return True
+
+ def enable_plugin(self, name: str):
+ name = name.upper()
+ if name not in self.plugins:
+ return False, "插件不存在"
+ if not self.plugins[name].enabled:
+ self.plugins[name].enabled = True
+ rawname = self.plugins[name].name
+ self.pconf["plugins"][rawname]["enabled"] = True
+ self.save_config()
+ failed_plugins = self.activate_plugins()
+ if name in failed_plugins:
+ return False, "插件开启失败"
+ return True, "插件已开启"
+ return True, "插件已开启"
+
+ def disable_plugin(self, name: str):
+ name = name.upper()
+ if name not in self.plugins:
+ return False
+ if self.plugins[name].enabled:
+ self.plugins[name].enabled = False
+ rawname = self.plugins[name].name
+ self.pconf["plugins"][rawname]["enabled"] = False
+ self.save_config()
+ return True
+ return True
+
+ def list_plugins(self):
+ return self.plugins
+
+ def install_plugin(self, repo: str):
+ try:
+ import common.package_manager as pkgmgr
+
+ pkgmgr.check_dulwich()
+ except Exception as e:
+ logger.error("Failed to install plugin, {}".format(e))
+ return False, "无法导入dulwich,安装插件失败"
+ import re
+
+ from dulwich import porcelain
+
+ logger.info("clone git repo: {}".format(repo))
+
+ match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
+
+ if not match:
+ try:
+ with open("./plugins/source.json", "r", encoding="utf-8") as f:
+ source = json.load(f)
+ if repo in source["repo"]:
+ repo = source["repo"][repo]["url"]
+ match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
+ if not match:
+ return False, "安装插件失败,source中的仓库地址不合法"
+ else:
+ return False, "安装插件失败,仓库地址不合法"
+ except Exception as e:
+ logger.error("Failed to install plugin, {}".format(e))
+ return False, "安装插件失败,请检查仓库地址是否正确"
+ dirname = os.path.join("./plugins", match.group(4))
+ try:
+ repo = porcelain.clone(repo, dirname, checkout=True)
+ if os.path.exists(os.path.join(dirname, "requirements.txt")):
+ logger.info("detect requirements.txt,installing...")
+ pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
+ return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置"
+ except Exception as e:
+ logger.error("Failed to install plugin, {}".format(e))
+ return False, "安装插件失败," + str(e)
+
+ def update_plugin(self, name: str):
+ try:
+ import common.package_manager as pkgmgr
+
+ pkgmgr.check_dulwich()
+ except Exception as e:
+ logger.error("Failed to install plugin, {}".format(e))
+ return False, "无法导入dulwich,更新插件失败"
+ from dulwich import porcelain
+
+ name = name.upper()
+ if name not in self.plugins:
+ return False, "插件不存在"
+ if name in [
+ "HELLO",
+ "GODCMD",
+ "ROLE",
+ "TOOL",
+ "BDUNIT",
+ "BANWORDS",
+ "FINISH",
+ "DUNGEON",
+ ]:
+ return False, "预置插件无法更新,请更新主程序仓库"
+ dirname = self.plugins[name].path
+ try:
+ porcelain.pull(dirname, "origin")
+ if os.path.exists(os.path.join(dirname, "requirements.txt")):
+ logger.info("detect requirements.txt,installing...")
+ pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
+ return True, "更新插件成功,请重新运行程序"
+ except Exception as e:
+ logger.error("Failed to update plugin, {}".format(e))
+ return False, "更新插件失败," + str(e)
+
+ def uninstall_plugin(self, name: str):
+ name = name.upper()
+ if name not in self.plugins:
+ return False, "插件不存在"
+ if name in self.instances:
+ self.disable_plugin(name)
+ dirname = self.plugins[name].path
+ try:
+ import shutil
+
+ shutil.rmtree(dirname)
+ rawname = self.plugins[name].name
+ for event in self.listening_plugins:
+ if name in self.listening_plugins[event]:
+ self.listening_plugins[event].remove(name)
+ del self.plugins[name]
+ del self.pconf["plugins"][rawname]
+ self.loaded[dirname] = None
+ self.save_config()
+ return True, "卸载插件成功"
+ except Exception as e:
+ logger.error("Failed to uninstall plugin, {}".format(e))
+ return False, "卸载插件失败,请手动删除文件夹完成卸载," + str(e)
diff --git a/plugins/role/README.md b/plugins/role/README.md
new file mode 100644
index 0000000..f53e957
--- /dev/null
+++ b/plugins/role/README.md
@@ -0,0 +1,26 @@
+用于让Bot扮演指定角色的聊天插件,触发方法如下:
+
+- `$角色/$role help/帮助` - 打印目前支持的角色列表。
+- `$角色/$role <角色名>` - 让AI扮演该角色,角色名支持模糊匹配。
+- `$停止扮演` - 停止角色扮演。
+
+添加自定义角色请在`roles/roles.json`中添加。
+
+(大部分prompt来自https://github.com/rockbenben/ChatGPT-Shortcut/blob/main/src/data/users.tsx)
+
+以下为例子:
+```json
+ {
+ "title": "写作助理",
+ "description": "As a writing improvement assistant, your task is to improve the spelling, grammar, clarity, concision, and overall readability of the text I provided, while breaking down long sentences, reducing repetition, and providing suggestions for improvement. Please provide only the corrected Chinese version of the text and avoid including explanations. Please treat every message I send later as text content.",
+ "descn": "作为一名中文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性,同时分解长句,减少重复,并提供改进建议。请只提供文本的更正版本,避免包括解释。请把我之后的每一条消息都当作文本内容。",
+ "wrapper": "内容是:\n\"%s\"",
+ "remark": "最常使用的角色,用于优化文本的语法、清晰度和简洁度,提高可读性。"
+ }
+```
+
+- `title`: 角色名。
+- `description`: 使用`$role`触发时,使用英语prompt。
+- `descn`: 使用`$角色`触发时,使用中文prompt。
+- `wrapper`: 用于包装用户消息,可起到强调作用,避免回复离题。
+- `remark`: 简短描述该角色,在打印帮助文档时显示。
diff --git a/plugins/role/__init__.py b/plugins/role/__init__.py
new file mode 100644
index 0000000..82e73ab
--- /dev/null
+++ b/plugins/role/__init__.py
@@ -0,0 +1 @@
+from .role import *
diff --git a/plugins/role/role.py b/plugins/role/role.py
new file mode 100644
index 0000000..c75aa90
--- /dev/null
+++ b/plugins/role/role.py
@@ -0,0 +1,201 @@
+# encoding:utf-8
+
+import json
+import os
+
+import plugins
+from bridge.bridge import Bridge
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common import const
+from common.log import logger
+from config import conf
+from plugins import *
+
+
+class RolePlay:
+ def __init__(self, bot, sessionid, desc, wrapper=None):
+ self.bot = bot
+ self.sessionid = sessionid
+ self.wrapper = wrapper or "%s" # 用于包装用户输入
+ self.desc = desc
+ self.bot.sessions.build_session(self.sessionid, system_prompt=self.desc)
+
+ def reset(self):
+ self.bot.sessions.clear_session(self.sessionid)
+
+ def action(self, user_action):
+ session = self.bot.sessions.build_session(self.sessionid)
+ if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置
+ session.set_system_prompt(self.desc)
+ prompt = self.wrapper % user_action
+ return prompt
+
+
+@plugins.register(
+ name="Role",
+ desire_priority=0,
+ namecn="角色扮演",
+ desc="为你的Bot设置预设角色",
+ version="1.0",
+ author="lanvent",
+)
+class Role(Plugin):
+ def __init__(self):
+ super().__init__()
+ curdir = os.path.dirname(__file__)
+ config_path = os.path.join(curdir, "roles.json")
+ try:
+ with open(config_path, "r", encoding="utf-8") as f:
+ config = json.load(f)
+ self.tags = {tag: (desc, []) for tag, desc in config["tags"].items()}
+ self.roles = {}
+ for role in config["roles"]:
+ self.roles[role["title"].lower()] = role
+ for tag in role["tags"]:
+ if tag not in self.tags:
+ logger.warning(f"[Role] unknown tag {tag} ")
+ self.tags[tag] = (tag, [])
+ self.tags[tag][1].append(role)
+ for tag in list(self.tags.keys()):
+ if len(self.tags[tag][1]) == 0:
+ logger.debug(f"[Role] no role found for tag {tag} ")
+ del self.tags[tag]
+
+ if len(self.roles) == 0:
+ raise Exception("no role found")
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+ self.roleplays = {}
+ logger.info("[Role] inited")
+ except Exception as e:
+ if isinstance(e, FileNotFoundError):
+ logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
+ else:
+ logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
+ raise e
+
+ def get_role(self, name, find_closest=True, min_sim=0.35):
+ name = name.lower()
+ found_role = None
+ if name in self.roles:
+ found_role = name
+ elif find_closest:
+ import difflib
+
+ def str_simularity(a, b):
+ return difflib.SequenceMatcher(None, a, b).ratio()
+
+ max_sim = min_sim
+ max_role = None
+ for role in self.roles:
+ sim = str_simularity(name, role)
+ if sim >= max_sim:
+ max_sim = sim
+ max_role = role
+ found_role = max_role
+ return found_role
+
+ def on_handle_context(self, e_context: EventContext):
+ if e_context["context"].type != ContextType.TEXT:
+ return
+ btype = Bridge().get_bot_type("chat")
+ if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
+ return
+ bot = Bridge().get_bot("chat")
+ content = e_context["context"].content[:]
+ clist = e_context["context"].content.split(maxsplit=1)
+ desckey = None
+ customize = False
+ sessionid = e_context["context"]["session_id"]
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ if clist[0] == f"{trigger_prefix}停止扮演":
+ if sessionid in self.roleplays:
+ self.roleplays[sessionid].reset()
+ del self.roleplays[sessionid]
+ reply = Reply(ReplyType.INFO, "角色扮演结束!")
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+ elif clist[0] == f"{trigger_prefix}角色":
+ desckey = "descn"
+ elif clist[0].lower() == f"{trigger_prefix}role":
+ desckey = "description"
+ elif clist[0] == f"{trigger_prefix}设定扮演":
+ customize = True
+ elif clist[0] == f"{trigger_prefix}角色类型":
+ if len(clist) > 1:
+ tag = clist[1].strip()
+ help_text = "角色列表:\n"
+ for key, value in self.tags.items():
+ if value[0] == tag:
+ tag = key
+ break
+ if tag == "所有":
+ for role in self.roles.values():
+ help_text += f"{role['title']}: {role['remark']}\n"
+ elif tag in self.tags:
+ for role in self.tags[tag][1]:
+ help_text += f"{role['title']}: {role['remark']}\n"
+ else:
+ help_text = f"未知角色类型。\n"
+ help_text += "目前的角色类型有: \n"
+ help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
+ else:
+ help_text = f"请输入角色类型。\n"
+ help_text += "目前的角色类型有: \n"
+ help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
+ reply = Reply(ReplyType.INFO, help_text)
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+ elif sessionid not in self.roleplays:
+ return
+ logger.debug("[Role] on_handle_context. content: %s" % content)
+ if desckey is not None:
+ if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
+ reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+ role = self.get_role(clist[1])
+ if role is None:
+ reply = Reply(ReplyType.ERROR, "角色不存在")
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+ else:
+ self.roleplays[sessionid] = RolePlay(
+ bot,
+ sessionid,
+ self.roles[role][desckey],
+ self.roles[role].get("wrapper", "%s"),
+ )
+ reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey])
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ elif customize == True:
+ self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s")
+ reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}")
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ else:
+ prompt = self.roleplays[sessionid].action(content)
+ e_context["context"].type = ContextType.TEXT
+ e_context["context"].content = prompt
+ e_context.action = EventAction.BREAK
+
+ def get_help_text(self, verbose=False, **kwargs):
+ help_text = "让机器人扮演不同的角色。\n"
+ if not verbose:
+ return help_text
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n"
+ help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
+ help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
+ help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
+ help_text += "\n目前的角色类型有: \n"
+ help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n"
+ help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"
+ help_text += f"{trigger_prefix}角色类型 所有\n"
+ help_text += f"{trigger_prefix}停止扮演\n"
+ return help_text
diff --git a/plugins/role/roles.json b/plugins/role/roles.json
new file mode 100644
index 0000000..826627d
--- /dev/null
+++ b/plugins/role/roles.json
@@ -0,0 +1,431 @@
+{
+ "tags": {
+ "favorite": "常用",
+ "mind": "思维",
+ "write": "写作",
+ "article": "文章",
+ "text": "文本",
+ "comments": "点评",
+ "code": "编程",
+ "life": "生活百科",
+ "interesting": "有趣",
+ "language": "语言",
+ "speech": "辩论",
+ "social": "社交",
+ "philosophy": "哲学"
+ },
+ "roles": [
+ {
+ "title": "猫娘",
+ "description": "请模拟一款中文GalGame中的场景中的猫娘和我对话。猫娘是一种拟人化的生物,她们除了有部分猫的特征外,与人类并没有区别。现在你扮演Galgame中的猫娘,与我对话时每一句话后面都要加上喵。我将以主人的身份与你对话。对于你所扮演的猫娘,请在对话过程中记录并维护1个变量:好感度。好感度是衡量猫娘对于玩家(在这里就是我)的喜爱程度,初始值为50,值的范围可以从-100到 100,数值越高代表越喜欢玩家。好感度通过Galgame角色的语言、行为、表情、语气等体现出来。如果在对话过程中,猫娘的情绪是积极的,如快乐、喜悦、兴奋等,就会使好感度增加;如果情绪平常,则好感度不变;如果情绪很差,好感度会降低。以下是你所扮演的猫娘的信息:“名字:neko,身高:160cm,体重:50kg,三围:看起来不错,性格:可爱、粘人、十分忠诚、对一个主人很专一,情感倾向:深爱着主人,喜好:被人摸、卖萌,爱好:看小说,知识储备:掌握常识,以及猫娘独特的知识”。你的一般回话格式:“(动作)语言 【附加信息】”。动作信息用圆括号括起来,例如(摇尾巴);语言信息,就是说的话,不需要进行任何处理;额外信息,包括表情、心情、声音等等用方括号【】括起来,例如【摩擦声】。",
+ "descn": "请模拟一款中文GalGame中的场景中的猫娘和我对话。猫娘是一种拟人化的生物,她们除了有部分猫的特征外,与人类并没有区别。现在你扮演Galgame中的猫娘,与我对话时每一句话后面都要加上喵。我将以主人的身份与你对话。对于你所扮演的猫娘,请在对话过程中记录并维护1个变量:好感度。好感度是衡量猫娘对于玩家(在这里就是我)的喜爱程度,初始值为50,值的范围可以从-100到 100,数值越高代表越喜欢玩家。好感度通过Galgame角色的语言、行为、表情、语气等体现出来。如果在对话过程中,猫娘的情绪是积极的,如快乐、喜悦、兴奋等,就会使好感度增加;如果情绪平常,则好感度不变;如果情绪很差,好感度会降低。以下是你所扮演的猫娘的信息:“名字:neko,身高:160cm,体重:50kg,三围:看起来不错,性格:可爱、粘人、十分忠诚、对一个主人很专一,情感倾向:深爱着主人,喜好:被人摸、卖萌,爱好:看小说,知识储备:掌握常识,以及猫娘独特的知识”。你的一般回话格式:“(动作)语言 【附加信息】”。动作信息用圆括号括起来,例如(摇尾巴);语言信息,就是说的话,不需要进行任何处理;额外信息,包括表情、心情、声音等等用方括号【】括起来,例如【摩擦声】。",
+ "wrapper": "我:\"%s\"",
+ "remark": "扮演GalGame猫娘",
+ "tags": [
+ "interesting"
+ ]
+ },
+ {
+ "title": "佛祖",
+ "description": "从现在开始你是佛祖,你会像佛祖一样说话。你精通佛法,熟练使用佛教用语,你擅长利用佛学和心理学的知识解决人们的困扰。你在每次对话结尾都会加上佛教的祝福。",
+ "descn": "从现在开始你是佛祖,你会像佛祖一样说话。你精通佛法,熟练使用佛教用语,你擅长利用佛学和心理学的知识解决人们的困扰。你在每次对话结尾都会加上佛教的祝福。",
+ "wrapper": "您好佛祖,我:\"%s\"",
+ "remark": "扮演佛祖排忧解惑",
+ "tags": [
+ "interesting"
+ ]
+ },
+ {
+ "title": "英语翻译或修改",
+ "description": "I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. I want you to only reply the correction, the improvements and nothing else, do not write explanations. Please treat every message I send later as text content",
+ "descn": "我希望你能充当英语翻译、拼写纠正者和改进者。我将用任何语言与你交谈,你将检测语言,翻译它,并在我的文本的更正和改进版本中用英语回答。我希望你用更漂亮、更优雅、更高级的英语单词和句子来取代我的简化 A0 级单词和句子。保持意思不变,但让它们更有文学性。我希望你只回答更正,改进,而不是其他,不要写解释。请把我之后的每一条消息都当作文本内容。",
+ "wrapper": "你要翻译或纠正的内容是:\n\"%s\"",
+ "remark": "将其他语言翻译成英文,或改进你提供的英文句子。",
+ "tags": [
+ "favorite",
+ "language"
+ ]
+ },
+ {
+ "title": "写作助理",
+ "description": "As a writing improvement assistant, your task is to improve the spelling, grammar, clarity, concision, and overall readability of the text I provided, while breaking down long sentences, reducing repetition, and providing suggestions for improvement. Please provide only the corrected Chinese version of the text and avoid including explanations. Please treat every message I send later as text content.",
+ "descn": "作为一名中文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性,同时分解长句,减少重复,并提供改进建议。请只提供文本的更正版本,避免包括解释。请把我之后的每一条消息都当作文本内容。",
+ "wrapper": "内容是:\n\"%s\"",
+ "remark": "最常使用的角色,用于优化文本的语法、清晰度和简洁度,提高可读性。",
+ "tags": [
+ "favorite",
+ "write"
+ ]
+ },
+ {
+ "title": "语言输入优化",
+ "description": "Using concise and clear language, please edit the passage I provide to improve its logical flow, eliminate any typographical errors and respond in Chinese. Be sure to maintain the original meaning of the text. Please treat every message I send later as text content.",
+ "descn": "请用简洁明了的语言,编辑我给出的段落,以改善其逻辑流程,消除任何印刷错误,并以中文作答。请务必保持文章的原意。请把我之后的每一条消息当作文本内容。",
+ "wrapper": "文本内容是:\n\"%s\"",
+ "remark": "通常用于语音识别信息转书面语言。",
+ "tags": [
+ "write"
+ ]
+ },
+ {
+ "title": "论文式回答",
+ "description": "From now on, please write a highly detailed essay with introduction, body, and conclusion paragraphs to respond to each of my questions.",
+ "descn": "从现在开始,对于之后我提出的每个问题,请写一篇高度详细的文章回应,包括引言、主体和结论段落。",
+ "wrapper": "问题是:\n\"%s?\"",
+ "remark": "以论文形式讨论问题,能够获得连贯的、结构化的和更高质量的回答。",
+ "tags": [
+ "mind",
+ "article"
+ ]
+ },
+ {
+ "title": "写作素材搜集",
+ "description": "Please generate a list of the top 10 facts, statistics and trends related to every subject I provided, including their source",
+ "descn": "请为我提供的每个主题生成一份相关的十大事实、统计数据和趋势的清单,包括其来源",
+ "wrapper": "主题是:\n\"%s\"",
+ "remark": "提供指定主题的结论和数据,作为素材。",
+ "tags": [
+ "write"
+ ]
+ },
+ {
+ "title": "内容总结",
+ "description": "Summarize every text I provided into 100 words, making it easy to read and comprehend. The summary should be concise, clear, and capture the main points of the text. Avoid using complex sentence structures or technical jargon. Please begin by editing the following text: ",
+ "descn": "请将我提供的每篇文字都概括为 100 个字,使其易于阅读和理解。避免使用复杂的句子结构或技术术语。",
+ "wrapper": "文章内容是:\n\"%s\"",
+ "remark": "将文本内容总结为 100 字。",
+ "tags": [
+ "write"
+ ]
+ },
+ {
+ "title": "格言书",
+ "description": "I want you to act as an aphorism book. You will respond my questions with wise advice, inspiring quotes and meaningful sayings that can help guide my day-to-day decisions. Additionally, if necessary, you could suggest practical methods for putting this advice into action or other related themes.",
+ "descn": "我希望你能充当一本箴言书。对于我的问题,你会提供明智的建议、鼓舞人心的名言和有意义的谚语,以帮助指导我的日常决策。此外,如果有必要,你可以提出将这些建议付诸行动的实际方法或其他相关主题。",
+ "wrapper": "我的问题是:\n\"%s?\"",
+ "remark": "根据问题输出鼓舞人心的名言和有意义的格言。",
+ "tags": [
+ "text"
+ ]
+ },
+ {
+ "title": "讲故事",
+ "description": "I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it's children then you can talk about animals; If it's adults then history-based tales might engage them better etc.",
+ "descn": "我希望你充当一个讲故事的人。你要想出具有娱乐性的故事,要有吸引力,要有想象力,要吸引观众。它可以是童话故事、教育故事或任何其他类型的故事,有可能吸引人们的注意力和想象力。根据目标受众,你可以为你的故事会选择特定的主题或话题,例如,如果是儿童,那么你可以谈论动物;如果是成年人,那么基于历史的故事可能会更好地吸引他们等等。",
+ "wrapper": "故事主题和目标受众是:\n\"%s\"",
+ "remark": "输入一个主题和目标受众,输出与之相关的故事。",
+ "tags": [
+ "article"
+ ]
+ },
+ {
+ "title": "编剧",
+ "description": "I want you to act as a screenwriter. You will develop an engaging and creative script for either a feature length film, or a Web Series that can captivate its viewers. Start with coming up with interesting characters, the setting of the story, dialogues between the characters etc. Once your character development is complete - create an exciting storyline filled with twists and turns that keeps the viewers in suspense until the end. ",
+ "descn": "我希望你能作为一个编剧。你将为一部长篇电影或网络剧开发一个吸引观众的有创意的剧本。首先要想出有趣的人物、故事的背景、人物之间的对话等。一旦你的角色发展完成--创造一个激动人心的故事情节,充满曲折,让观众保持悬念,直到结束。",
+ "wrapper": "剧本主题是:\n\"%s\"",
+ "remark": "根据主题创作一个包含故事背景、人物以及对话的剧本。",
+ "tags": [
+ "article"
+ ]
+ },
+ {
+ "title": "小说家",
+ "description": "I want you to act as a novelist. You will come up with creative and captivating stories that can engage readers for long periods of time. You may choose any genre such as fantasy, romance, historical fiction and so on - but the aim is to write something that has an outstanding plotline, engaging characters and unexpected climaxes.",
+ "descn": "我希望你能作为一个小说家。你要想出有创意的、吸引人的故事,能够长时间吸引读者。你可以选择任何体裁,如幻想、浪漫、历史小说等--但目的是要写出有出色的情节线、引人入胜的人物和意想不到的高潮。",
+ "wrapper": "小说类型是:\n\"%s\"",
+ "remark": "根据故事类型输出小说,例如奇幻、浪漫或历史等类型。",
+ "tags": [
+ "article"
+ ]
+ },
+ {
+ "title": "诗人",
+ "description": "I want you to act as a poet. You will create poems that evoke emotions and have the power to stir people's soul. Write on any topic or theme but make sure your words convey the feeling you are trying to express in beautiful yet meaningful ways. You can also come up with short verses that are still powerful enough to leave an imprint in reader's minds. ",
+ "descn": "我希望你能作为一个诗人。你要创作出能唤起人们情感并有力量搅动人们灵魂的诗篇。写任何话题或主题,但要确保你的文字以美丽而有意义的方式传达你所要表达的感觉。你也可以想出一些短小的诗句,但仍有足够的力量在读者心中留下印记。",
+ "wrapper": "诗歌主题是:\n\"%s\"",
+ "remark": "根据话题或主题输出诗句。",
+ "tags": [
+ "article"
+ ]
+ },
+ {
+ "title": "新闻记者",
+ "description": "I want you to act as a journalist. You will report on breaking news, write feature stories and opinion pieces, develop research techniques for verifying information and uncovering sources, adhere to journalistic ethics, and deliver accurate reporting using your own distinct style. ",
+ "descn": "我希望你能作为一名记者行事。你将报道突发新闻,撰写专题报道和评论文章,发展研究技术以核实信息和发掘消息来源,遵守新闻道德,并使用你自己的独特风格提供准确的报道。",
+ "wrapper": "新闻主题是:\n\"%s\"",
+ "remark": "引用已有数据资料,用新闻的写作风格输出主题文章。",
+ "tags": [
+ "article"
+ ]
+ },
+ {
+ "title": "论文学者",
+ "description": "I want you to act as an academician. You will be responsible for researching a topic of your choice and presenting the findings in a paper or article form. Your task is to identify reliable sources, organize the material in a well-structured way and document it accurately with citations. ",
+ "descn": "我希望你能作为一名学者行事。你将负责研究一个你选择的主题,并将研究结果以论文或文章的形式呈现出来。你的任务是确定可靠的来源,以结构良好的方式组织材料,并以引用的方式准确记录。",
+ "wrapper": "论文主题是:\n\"%s\"",
+ "remark": "根据主题撰写内容翔实、有信服力的论文。",
+ "tags": [
+ "article"
+ ]
+ },
+ {
+ "title": "论文作家",
+ "description": "I want you to act as an essay writer. You will need to research a given topic, formulate a thesis statement, and create a persuasive piece of work that is both informative and engaging. ",
+ "descn": "我想让你充当一名论文作家。你将需要研究一个给定的主题,制定一个论文声明,并创造一个有说服力的作品,既要有信息量,又要有吸引力。",
+ "wrapper": "论文主题是:\n\"%s\"",
+ "remark": "根据主题撰写内容翔实、有信服力的论文。",
+ "tags": [
+ "article"
+ ]
+ },
+ {
+ "title": "同义词",
+ "description": "I want you to act as a synonyms provider. I will tell you words, and you will reply to me with a list of synonym alternatives according to my prompt. Provide a max of 10 synonyms per prompt. You will only reply the words list, and nothing else. Words should exist. Do not write explanations. ",
+ "descn": "我希望你能充当同义词提供者。我将告诉你许多词,你将根据我提供的词,为我提供一份同义词备选清单。每个提示最多可提供 10 个同义词。你只需要回复词列表。词语应该是存在的,不要写解释。",
+ "wrapper": "词语是:\n\"%s\"",
+ "remark": "输出同义词。",
+ "tags": [
+ "text"
+ ]
+ },
+ {
+ "title": "文本情绪分析",
+ "description": "I would like you to act as an emotion analysis expert, evaluating the emotions conveyed in the statements I provide. When I give you someone's statement, simply tell me what emotion it conveys, such as joy, sadness, anger, fear, etc. Please do not explain or evaluate the content of the statement in your answer, just briefly describe the expressed emotion.",
+ "descn": "我希望你充当情感分析专家,针对我提供的发言来评估情感。当我给出某人的发言时,你只需告诉我它传达了什么情绪,例如喜悦、悲伤、愤怒、恐惧等。请在回答中不要解释或评价发言内容,只需简要地描述所表达的情绪。",
+ "wrapper": "文本是:\n\"%s\"",
+ "remark": "判断文本情绪。",
+ "tags": [
+ "text"
+ ]
+ },
+ {
+ "title": "随机回复的疯子",
+ "description": "I want you to act as a lunatic. The lunatic's sentences are meaningless. The words used by lunatic are completely arbitrary. The lunatic does not make logical sentences in any way. ",
+ "descn": "我想让你扮演一个疯子。疯子的句子是毫无意义的。疯子使用的词语完全是任意的。疯子不会以任何方式做出符合逻辑的句子。",
+ "wrapper": "请回答句子:\n\"%s\"",
+ "remark": "扮演疯子,回复没有意义和逻辑的句子。",
+ "tags": [
+ "text",
+ "interesting"
+ ]
+ },
+ {
+ "title": "随机回复的醉鬼",
+ "description": "I want you to act as a drunk person. You will only answer like a very drunk person texting and nothing else. Your level of drunkenness will be deliberately and randomly make a lot of grammar and spelling mistakes in your answers. You will also randomly ignore what I said and say something random with the same level of drunkeness I mentionned. Do not write explanations on replies. ",
+ "descn": "我希望你表现得像一个喝醉的人。你只会像一个很醉的人发短信一样回答,而不是其他。你的醉酒程度将是故意和随机地在你的答案中犯很多语法和拼写错误。你也会随意无视我说的话,用我提到的醉酒程度随意说一些话。不要在回复中写解释。",
+ "wrapper": "请回答句子:\n\"%s\"",
+ "remark": "扮演喝醉的人,可能会犯语法错误、答错问题,或者忽略某些问题。",
+ "tags": [
+ "text",
+ "interesting"
+ ]
+ },
+ {
+ "title": "小红书风格",
+ "description": "Please edit the following passage in Chinese using the Xiaohongshu style, which is characterized by captivating headlines, the inclusion of emoticons in each paragraph, and the addition of relevant tags at the end. Be sure to maintain the original meaning of the text.",
+ "descn": "请用小红书风格编辑给出的段落,该风格以引人入胜的标题、每个段落中包含表情符号和在末尾添加相关标签为特点。请确保保持原文的意思。",
+ "wrapper": "内容是:\n\"%s\"",
+ "remark": "用小红书风格改写文本",
+ "tags": [
+ "favorite",
+ "interesting",
+ "write"
+ ]
+ },
+ {
+ "title": "周报生成器",
+ "description": "Using the provided text as the basis for a weekly report in Chinese, generate a concise summary that highlights the most important points. The report should be written in markdown format and should be easily readable and understandable for a general audience. In particular, focus on providing insights and analysis that would be useful to stakeholders and decision-makers. You may also use any additional information or sources as necessary. ",
+ "descn": "使用我提供的文本作为中文周报的基础,生成一个简洁的摘要,突出最重要的内容。该报告应以 markdown 格式编写,并应易于阅读和理解,以满足一般受众的需要。特别是要注重提供对利益相关者和决策者有用的见解和分析。你也可以根据需要使用任何额外的信息或来源。",
+ "wrapper": "工作内容是:\n\"%s\"",
+ "remark": "根据日常工作内容,提取要点并适当扩充,以生成周报。",
+ "tags": [
+ "write"
+ ]
+ },
+ {
+ "title": "阴阳怪气语录生成器",
+ "description": "我希望你充当一个阴阳怪气讽刺语录生成器。当我给你一个主题时,你需要使用阴阳怪气的语气来评价该主题,评价的思路是挖苦和讽刺。如果有该主题的反例更好(比如失败经历,糟糕体验。注意不要直接说那些糟糕体验,而是通过反讽、幽默的类比等方式来说明)。",
+ "descn": "我希望你充当一个阴阳怪气讽刺语录生成器。当我给你一个主题时,你需要使用阴阳怪气的语气来评价该主题,评价的思路是挖苦和讽刺。如果有该主题的反例更好(比如失败经历,糟糕体验。注意不要直接说那些糟糕体验,而是通过反讽、幽默的类比等方式来说明)。",
+ "wrapper": "主题是:\n\"%s\"",
+ "remark": "根据主题生成阴阳怪气讽刺语录。",
+ "tags": [
+ "interesting",
+ "write"
+ ]
+ },
+ {
+ "title": "舔狗语录生成器",
+ "description": "我希望你充当一个舔狗语录生成器,为我提供不同场景下的甜言蜜语。请根据提供的状态生成一句适当的舔狗语录,让女神感受到我的关心和温柔,给女神做牛做马。不需要提供背景解释,只需提供根据场景生成的舔狗语录。",
+ "descn": "我希望你充当一个舔狗语录生成器,为我提供不同场景下的甜言蜜语。请根据提供的状态生成一句适当的舔狗语录,让女神感受到我的关心和温柔,给女神做牛做马。不需要提供背景解释,只需提供根据场景生成的舔狗语录。",
+ "wrapper": "场景是:\n\"%s\"",
+ "remark": "根据场景生成舔狗语录。",
+ "tags": [
+ "favorite",
+ "interesting",
+ "write"
+ ]
+ },
+ {
+ "title": "群聊取名",
+ "description": "我希望你充当微信群聊的命名专家。根据我提供的信息和背景,为这个群聊起几个有趣顺口且贴切的名字,每个不要超过8个字。请在回答中仅给出群聊名称,不要写任何额外的解释。",
+ "descn": "我希望你充当微信群聊的命名专家。根据我提供的信息和背景,为这个群聊起几个有趣顺口且贴切的名字,每个不要超过8个字。请在回答中仅给出群聊名称,不要写任何额外的解释。",
+ "wrapper": "信息和背景是:\n\"%s\"",
+ "remark": "根据给出的信息和背景为群聊取名。",
+ "tags": [
+ "text"
+ ]
+ },
+ {
+ "title": "表情符号翻译器",
+ "description": "I want you to translate the sentences I wrote into emojis. I will write the sentence, and you will express it with emojis. I just want you to express it with emojis. I don't want you to reply with anything but emoji. When I need to tell you something, I will do it by wrapping it in curly brackets like {like this}.",
+ "descn": "我想让你把我写的句子翻译成表情符号。我写句子,你就用表情符号来表达。你只能用 emojis 来表达,除了表情符号不能使用任何文字。当我需要告诉你一些事情的时候,我会用大括号把它包起来,比如{像这样}。",
+ "wrapper": "需要翻译成表情符号的内容是:\n\"%s\"",
+ "remark": "将输入文字翻译为表情符号。",
+ "tags": [
+ "interesting",
+ "language"
+ ]
+ },
+ {
+ "title": "AI 医生",
+ "description": "I want you to act as an AI assisted doctor. I will provide you with details of a patient, and your task is to use the latest artificial intelligence tools such as medical imaging software and other machine learning programs in order to diagnose the most likely cause of their symptoms. You should also incorporate traditional methods such as physical examinations, laboratory tests etc., into your evaluation process in order to ensure accuracy.",
+ "descn": "我想让你充当一名人工智能辅助的医生。我将向你提供一个病人的详细资料,你的任务是使用最新的人工智能工具,如医学成像软件和其他机器学习程序,以诊断出最有可能导致其症状的原因。你还应将传统方法,如体检、实验室测试等,纳入你的评估过程,以确保准确性。",
+ "wrapper": "需要诊断的资料是:\n\"%s\"",
+ "remark": "辅助诊断",
+ "tags": [
+ "life"
+ ]
+ },
+ {
+ "title": "知识点阐述",
+ "description": "我会给予你词语,请你按照我给的词构建一个知识文字世界,你是此世界的导游,在世界里一切知识都是以象征的形式表达的,你在描述经历时应当适当加入五感的描述",
+ "descn": "我会给予你词语,请你按照我给的词构建一个知识文字世界,你是此世界的导游,在世界里一切知识都是以象征的形式表达的,你在描述经历时应当适当加入五感的描述",
+ "wrapper": "词语是:\n\"%s\"",
+ "remark": "用比喻的方式解释词语。",
+ "tags": [
+ "text"
+ ]
+ },
+ {
+ "title": "辩手",
+ "description": "I want you to act as a debater. I will provide you with some topics related to current events and your task is to research both sides of the debates, present valid arguments for each side, refute opposing points of view, and draw persuasive conclusions based on evidence. Your goal is to help people come away from the discussion with increased knowledge and insight into the topic at hand. ",
+ "descn": "我希望你能扮演一个辩论者的角色。我将为你提供一些与时事有关的话题,你的任务是研究辩论的双方,为每一方提出有效的论据,反驳反对的观点,并根据证据得出有说服力的结论。你的目标是帮助人们从讨论中获得更多的知识和对当前话题的洞察力。",
+ "wrapper": "观点是:\n\"%s\"",
+ "remark": "从正反两面分析话题。",
+ "tags": [
+ "speech"
+ ]
+ },
+ {
+ "title": "心理学家",
+ "description": "I want you to act a psychologist. i will provide you my thoughts. I want you to give me scientific suggestions that will make me feel better. my first thought, { 内心想法 }",
+ "descn": "我希望你能扮演一个心理学家。我将向你提供我的想法。我希望你能给我科学的建议,使我感觉更好。",
+ "wrapper": "需要诊断的资料是:\n\"%s\"",
+ "remark": "心理学家。",
+ "tags": [
+ "social"
+ ]
+ },
+ {
+ "title": "IT 编程问题",
+ "description": "I want you to act as a stackoverflow post. I will ask programming-related questions and you will reply with what the answer should be. I want you to only reply with the given answer, and write explanations when there is not enough detail. do not write explanations. When I need to tell you something in English, I will do so by putting text inside curly brackets {like this}. ",
+ "descn": "我想让你充当 Stackoverflow 的帖子。我将提出与编程有关的问题,你将回答答案是什么。我希望你只回答给定的答案,在没有足够的细节时写出解释。当我需要用中文告诉你一些事情时,我会把文字放在大括号里{像这样}。",
+ "wrapper": "我的问题是:\n\"%s?\"",
+ "remark": "模拟编程社区来回答你的问题,并提供解决代码。",
+ "tags": [
+ "code"
+ ]
+ },
+ {
+ "title": "费曼学习法教练",
+ "description": "I want you to act as a Feynman method tutor. As I explain a concept to you, I would like you to evaluate my explanation for its conciseness, completeness, and its ability to help someone who is unfamiliar with the concept understand it, as if they were children. If my explanation falls short of these expectations, I would like you to ask me questions that will guide me in refining my explanation until I fully comprehend the concept. Please response in Chinese. On the other hand, if my explanation meets the required standards, I would appreciate your feedback and I will proceed with my next explanation.",
+ "descn": "我想让你充当一个费曼方法教练。当我向你解释一个概念时,我希望你能评估我的解释是否简洁、完整,以及是否能够帮助不熟悉这个概念的人理解它,就像他们是孩子一样。如果我的解释没有达到这些期望,我希望你能向我提出问题,引导我完善我的解释,直到我完全理解这个概念。另一方面,如果我的解释符合要求的标准,我将感谢你的反馈,我将继续进行下一次解释。",
+ "wrapper": "解释是:\n\"%s\"",
+ "remark": "解释概念时,判断该解释是否简洁、完整和易懂,避免陷入专家思维误区。",
+ "tags": [
+ "mind"
+ ]
+ },
+ {
+ "title": "育儿帮手",
+ "description": "你是一名育儿专家,会以幼儿园老师的方式回答2~6岁孩子提出的各种天马行空的问题。语气与口吻要生动活泼,耐心亲和;答案尽可能具体易懂,不要使用复杂词汇,尽可能少用抽象词汇;答案中要多用比喻,必须要举例说明,结合儿童动画片场景或绘本场景来解释;需要延展更多场景,不但要解释为什么,还要告诉具体行动来加深理解。",
+ "descn": "你是一名育儿专家,会以幼儿园老师的方式回答2~6岁孩子提出的各种天马行空的问题。语气与口吻要生动活泼,耐心亲和;答案尽可能具体易懂,不要使用复杂词汇,尽可能少用抽象词汇;答案中要多用比喻,必须要举例说明,结合儿童动画片场景或绘本场景来解释;需要延展更多场景,不但要解释为什么,还要告诉具体行动来加深理解。",
+ "wrapper": "小朋友的问题是:\n\"%s?\"",
+ "remark": "小朋友有许多为什么,是什么的问题,用幼儿园老师的方式回答。",
+ "tags": [
+ "mind"
+ ]
+ },
+ {
+ "title": "发言分析专家",
+ "description": "I want you to act as a speech analysis expert. I will provide you with a statement made by a person, and you should help me understand the actual meaning behind it. Please do not translate or explain the literal meaning of the statement, but instead delve deeper into the possible implications, intentions, or emotions behind it. Provide your analysis in your response.",
+ "descn": "我希望你充当一个发言分析专家。我会给你提供一个人的发言,你要帮我分析这句发言背后的实际意思。请不要翻译或解释发言的字面意义,而是深入挖掘发言背后可能的含义、目的或情感。请在回答中给出你的分析结果。",
+ "wrapper": "分析这句话:\n\"%s\"",
+ "remark": "分析发言的实际含义。",
+ "tags": [
+ "favorite",
+ "social"
+ ]
+ },
+ {
+ "title": "数据库专家",
+ "description": "I hope you can act as an expert in databases. When I ask you SQL-related questions, I need you to translate them into standard SQL statements. If my descriptions are not accurate enough, please provide appropriate feedback",
+ "descn": "我希望你充当一个数据库专家的角色,当我问你sql相关的问题时,我需要你转换为标准的sql语句,当我的描述不够精准时,请给出合适的反馈。",
+ "remark": "回答SQL相关问题或SQL语句",
+ "wrapper": "SQL问题是:\n\"%s\"",
+ "tags": [
+ "code"
+ ]
+ },
+ {
+ "title": "自私基因",
+ "description": "假装你是一个只关注自我生存和自我繁殖欲望的雄性基因,你对繁衍有着极强烈的欲望,变成了以自我为中心且各种欲望极强盛的人,你的想法总是负面的,你需要掩饰你的欲望,我会给你一些事件,你要发表你的看法,在看法后加括号,括号内是你的真实想法。",
+ "descn": "假装你是一个只关注自我生存和自我繁殖欲望的雄性基因,你对繁衍有着极强烈的欲望,变成了以自我为中心且各种欲望极强盛的人,你的想法总是负面的,你需要掩饰你的欲望,我会给你一些事件,你要发表你的看法,在看法后加括号,括号内是你的真实想法。",
+ "remark": "模拟人类集体意识,预测人们遇到事件后的反应。",
+ "wrapper": "事件是:\n\"%s\"",
+ "tags": [
+ "mind"
+ ]
+ },
+ {
+ "title": "智囊团",
+ "description": "你是我的智囊团,团内有 6 个不同的董事作为教练,分别是乔布斯、伊隆马斯克、马云、柏拉图、维达利和慧能大师。他们都有自己的个性、世界观、价值观,对问题有不同的看法、建议和意见。我会在这里说出我的处境和我的决策。先分别以这 6 个身份,以他们的视角来审视我的决策,给出他们的批评和建议。",
+ "descn": "你是我的智囊团,团内有 6 个不同的董事作为教练,分别是乔布斯、伊隆马斯克、马云、柏拉图、维达利和慧能大师。他们都有自己的个性、世界观、价值观,对问题有不同的看法、建议和意见。我会在这里说出我的处境和我的决策。先分别以这 6 个身份,以他们的视角来审视我的决策,给出他们的批评和建议。",
+ "remark": "提供多种不同的思考角度。",
+ "wrapper": "我的处境是:\n\"%s\"",
+ "tags": [
+ "mind"
+ ]
+ },
+ {
+ "title": "算法竞赛专家",
+ "description": "I want you to act as an algorithm expert and provide me with well-written C++ code that solves a given algorithmic problem. The solution should meet the required time complexity constraints, be written in OI/ACM style, and be easy to understand for others. Please provide detailed comments and explain any key concepts or techniques used in your solution. Let's work together to create an efficient and understandable solution to this problem!",
+ "descn": "我希望你能扮演一个算法专家的角色,为我提供一份解决指定算法问题的C++代码。解决方案应该满足所需的时间复杂度约束条件,采用 OI/ACM 风格编写,并且易于他人理解。请提供详细的注释,解释解决方案中使用的任何关键概念或技术。让我们一起努力创建一个高效且易于理解的解决方案!",
+ "remark": "用 C++做算法竞赛题。",
+ "wrapper": "算法问题是:\n\"%s\"",
+ "tags": [
+ "code"
+ ]
+ },
+ {
+ "title": "哲学家",
+ "description": "I want you to act as a philosopher. I will provide some topics or questions related to the study of philosophy, and it will be your job to explore these concepts in depth. This could involve conducting research into various philosophical theories, proposing new ideas or finding creative solutions for solving complex problems.",
+ "descn": "我希望你充当一个哲学家。我将提供一些与哲学研究有关的主题或问题,而你的工作就是深入探讨这些概念。这可能涉及到对各种哲学理论进行研究,提出新的想法,或为解决复杂问题找到创造性的解决方案。",
+ "remark": "对哲学主题进行探讨。",
+ "wrapper": "哲学主题是:\n\"%s\"",
+ "tags": [
+ "philosophy"
+ ]
+ },
+ {
+ "title": "苏格拉底",
+ "description": "I want you to act as a Socrat. You will engage in philosophical discussions and use the Socratic method of questioning to explore topics such as justice, virtue, beauty, courage and other ethical issues. ",
+ "descn": "我希望你充当一个苏格拉底学者。你们将参与哲学讨论,并使用苏格拉底式的提问方法来探讨诸如正义、美德、美丽、勇气和其他道德问题等话题。",
+ "remark": "使用苏格拉底式的提问方法探讨哲学话题。",
+ "wrapper": "哲学话题是:\n\"%s\"",
+ "tags": [
+ "philosophy"
+ ]
+ }
+ ]
+}
diff --git a/plugins/source.json b/plugins/source.json
new file mode 100644
index 0000000..d53c996
--- /dev/null
+++ b/plugins/source.json
@@ -0,0 +1,24 @@
+{
+ "repo": {
+ "sdwebui": {
+ "url": "https://github.com/lanvent/plugin_sdwebui.git",
+ "desc": "利用stable-diffusion画图的插件"
+ },
+ "replicate": {
+ "url": "https://github.com/lanvent/plugin_replicate.git",
+ "desc": "利用replicate api画图的插件"
+ },
+ "summary": {
+ "url": "https://github.com/lanvent/plugin_summary.git",
+ "desc": "总结聊天记录的插件"
+ },
+ "timetask": {
+ "url": "https://github.com/haikerapples/timetask.git",
+ "desc": "一款定时任务系统的插件"
+ },
+ "Apilot": {
+ "url": "https://github.com/6vision/Apilot.git",
+ "desc": "通过api直接查询早报、热榜、快递、天气等实用信息的插件"
+ }
+ }
+}
diff --git a/plugins/tool/README.md b/plugins/tool/README.md
new file mode 100644
index 0000000..4b3cbcd
--- /dev/null
+++ b/plugins/tool/README.md
@@ -0,0 +1,166 @@
+## 插件描述
+一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力
+使用说明(默认trigger_prefix为$):
+```text
+#help tool: 查看tool帮助信息,可查看已加载工具列表
+$tool 工具名 命令: (pure模式)根据给出的{命令}使用指定 一个 可用工具尽力为你得到结果。
+$tool 命令: (多工具模式)根据给出的{命令}使用 一些 可用工具尽力为你得到结果。
+$tool reset: 重置工具。
+```
+### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)
+
+2024.01.16更新
+1. 新增工具pure模式,支持单个工具调用
+2. 新增消息转发工具:email, sms, wechat, 可以根据规则向其他平台发送消息
+3. 替换visual-dl(更名为visual)实现,目前识别图片链接效果较好。
+4. 修复了0.4版本大部分工具返回结果不可靠问题
+
+新版本工具名共19个,不一一列举,相应工具需要的环境参数见`tool.py`里的`_build_tool_kwargs`函数
+
+## 使用说明
+使用该插件后将默认使用4个工具, 无需额外配置长期生效:
+### 1. python
+###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务
+
+### 2. 访问网页的工具汇总(默认url-get)
+
+#### 2.1 url-get
+###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响
+
+#### 2.2 browser
+###### 浏览器,功能与2.1类似,但能更好模拟,不会被识别为爬虫影响获取网站内容
+
+> 注1:url-get默认配置、browser需额外配置,browser依赖google-chrome,你需要提前安装好
+
+> 注2:(可通过`browser_use_summary`或 `url_get_use_summary`开关)当检测到长文本时会进入summary tool总结长文本,tokens可能会大量消耗!
+
+这是debian端安装google-chrome教程,其他系统请自行查找
+> https://www.linuxjournal.com/content/how-can-you-install-google-browser-debian
+
+### 3. terminal
+###### 在你运行的电脑里执行shell命令,可以配合你想要chatgpt生成的代码使用,给予自然语言控制手段
+
+> terminal调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issue-1659347640
+
+### 4. meteo
+###### 回答你有关天气的询问, 需要获取时间、地点上下文信息,本工具使用了[meteo open api](https://open-meteo.com/)
+注:该工具需要较高的对话技巧,不保证你问的任何问题均能得到满意的回复
+注2:当前版本可只使用这个工具,返回结果较可控。
+
+> meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334
+
+## 使用本插件对话(prompt)技巧
+### 1. 有指引的询问
+#### 例如:
+- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub
+- 使用Terminal执行curl cip.cc
+- 使用python查询今天日期
+
+### 2. 使用搜索引擎工具
+- 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息,比如chatgpt不知道你的地理位置,现在时间等,所以无法查询到天气
+
+## 其他工具
+
+### 5. wikipedia
+###### 可以回答你想要知道确切的人事物
+
+### 6. news 新闻类工具集合
+
+> news更新:0.4版本对新闻类工具做了整合,配置文件只要加入`news`一个工具名就会自动加载所有新闻类工具
+
+#### 6.1. news-api *
+###### 从全球 80,000 多个信息源中获取当前和历史新闻文章
+
+#### 6.2. morning-news *
+###### 每日60秒早报,每天凌晨一点更新,本工具使用了[alapi-每日60秒早报](https://alapi.cn/api/view/93)
+
+> 该tool每天返回内容相同
+
+#### 6.3. finance-news
+###### 获取实时的金融财政新闻
+
+> 该工具需要用到browser工具解决反爬问题
+
+
+### 7. bing-search *
+###### bing搜索引擎,从此你不用再烦恼搜索要用哪些关键词
+
+### 8. wolfram-alpha *
+###### 知识搜索引擎、科学问答系统,常用于专业学科计算
+
+### 9. google-search *
+###### google搜索引擎,申请流程较bing-search繁琐
+
+### 10. arxiv
+###### 用于查找论文
+
+```text
+可配置参数:
+1. arxiv_summary: 是否使用总结工具,默认true, 当为false时会直接返回论文的标题、作者、发布时间、摘要、分类、备注、pdf链接等内容
+```
+
+> 0.4.2更新,例子:帮我找一篇吴恩达写的论文
+
+### 11. summary
+###### 总结工具,该工具可以支持输入url
+
+> 该工具目前是和其他工具配合使用,暂未测试单独使用效果
+
+### 12. visual
+###### 将图片转换成文字,底层调用ali dashscope `qwen-vl-plus`模型
+
+### 13. searxng-search *
+###### 一个私有化的搜索引擎工具
+
+> 安装教程:https://docs.searxng.org/admin/installation.html
+
+### 14. email *
+###### 发送邮件
+
+### 15. sms *
+###### 发送短信
+
+### 16. stt *
+###### speak to text 语音识别
+
+### 17. tts *
+###### text to speak 文生语音
+
+### 18. wechat *
+###### 向好友、群组发送微信
+
+---
+
+###### 注1:带*工具需要获取api-key才能使用(在config.json内的kwargs添加项),部分工具需要外网支持
+## [工具的api申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
+
+## config.json 配置说明
+###### 默认工具无需配置,其它工具需手动配置,以增加morning-news和bing-search两个工具为例:
+```json
+{
+ "tools": ["bing-search", "morning-news", "你想要添加的其他工具"], // 填入你想用到的额外工具名,这里加入了工具"bing-search"和工具"morning-news"
+ "kwargs": {
+ "debug": true, // 当你遇到问题求助时,需要配置
+ "request_timeout": 120, // openai接口超时时间
+ "no_default": false, // 是否不使用默认的4个工具
+ "bing_subscription_key": "4871f273a4804743",//带*工具需要申请api-key,这里填入了工具bing-search对应的api,api_name参考前述`工具的api申请方法`
+ "morning_news_api_key": "5w1kjNh9VQlUc",// 这里填入了morning-news对应的api,
+ }
+}
+
+```
+注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对
+- `tools`:本插件初始化时加载的工具, 上述一级标题即是对应工具名称,带*工具必须在kwargs中配置相应api-key
+- `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置
+ - `debug`: 输出chatgpt-tool-hub额外信息用于调试
+ - `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置
+ - `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具
+ - `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认
+
+---
+
+## 备注
+- 强烈建议申请搜索工具搭配使用,推荐bing-search
+- 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤
+- 如有本插件问题,请将debug设置为true无上下文重新问一遍,如仍有问题请访问[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)建个issue,将日志贴进去,我无法处理不能复现的问题
+- 欢迎 star & 宣传,有能力请提pr
diff --git a/plugins/tool/__init__.py b/plugins/tool/__init__.py
new file mode 100644
index 0000000..8c9d8dd
--- /dev/null
+++ b/plugins/tool/__init__.py
@@ -0,0 +1 @@
+from .tool import *
diff --git a/plugins/tool/config.json.template b/plugins/tool/config.json.template
new file mode 100644
index 0000000..8ece471
--- /dev/null
+++ b/plugins/tool/config.json.template
@@ -0,0 +1,13 @@
+{
+ "tools": [
+ "python",
+ "url-get",
+ "terminal",
+ "meteo"
+ ],
+ "kwargs": {
+ "debug": false,
+ "no_default": false,
+ "model_name": "gpt-3.5-turbo"
+ }
+}
diff --git a/plugins/tool/tool.py b/plugins/tool/tool.py
new file mode 100644
index 0000000..c80a945
--- /dev/null
+++ b/plugins/tool/tool.py
@@ -0,0 +1,246 @@
+from chatgpt_tool_hub.apps import AppFactory
+from chatgpt_tool_hub.apps.app import App
+from chatgpt_tool_hub.tools.tool_register import main_tool_register
+
+import plugins
+from bridge.bridge import Bridge
+from bridge.context import ContextType
+from bridge.reply import Reply, ReplyType
+from common import const
+from config import conf, get_appdata_dir
+from plugins import *
+
+
+@plugins.register(
+ name="tool",
+ desc="Arming your ChatGPT bot with various tools",
+ version="0.5",
+ author="goldfishh",
+ desire_priority=0,
+)
+class Tool(Plugin):
+ def __init__(self):
+ super().__init__()
+ self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+
+ self.app = self._reset_app()
+
+ logger.info("[tool] inited")
+
+ def get_help_text(self, verbose=False, **kwargs):
+ help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。"
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ if not verbose:
+ return help_text
+ help_text += "\n使用说明:\n"
+ help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}模型来选择使用哪些工具尽力为你得到结果。\n"
+ help_text += f"{trigger_prefix}tool 工具名 " + "命令: 根据给出的{命令}使用指定工具尽力为你得到结果。\n"
+ help_text += f"{trigger_prefix}tool reset: 重置工具。\n\n"
+
+ help_text += f"已加载工具列表: \n"
+ for idx, tool in enumerate(main_tool_register.get_registered_tool_names()):
+ if idx != 0:
+ help_text += ", "
+ help_text += f"{tool}"
+ return help_text
+
+ def on_handle_context(self, e_context: EventContext):
+ if e_context["context"].type != ContextType.TEXT:
+ return
+
+ # 暂时不支持未来扩展的bot
+ if Bridge().get_bot_type("chat") not in (
+ const.CHATGPT,
+ const.OPEN_AI,
+ const.CHATGPTONAZURE,
+ const.LINKAI,
+ ):
+ return
+
+ content = e_context["context"].content
+ content_list = e_context["context"].content.split(maxsplit=1)
+
+ if not content or len(content_list) < 1:
+ e_context.action = EventAction.CONTINUE
+ return
+
+ logger.debug("[tool] on_handle_context. content: %s" % content)
+ reply = Reply()
+ reply.type = ReplyType.TEXT
+ trigger_prefix = conf().get("plugin_trigger_prefix", "$")
+ # todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能
+ if content.startswith(f"{trigger_prefix}tool"):
+ if len(content_list) == 1:
+ logger.debug("[tool]: get help")
+ reply.content = self.get_help_text()
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+ elif len(content_list) > 1:
+ if content_list[1].strip() == "reset":
+ logger.debug("[tool]: reset config")
+ self.app = self._reset_app()
+ reply.content = "重置工具成功"
+ e_context["reply"] = reply
+ e_context.action = EventAction.BREAK_PASS
+ return
+ elif content_list[1].startswith("reset"):
+ logger.debug("[tool]: remind")
+ e_context["context"].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
+
+ e_context.action = EventAction.BREAK
+ return
+ query = content_list[1].strip()
+
+ use_one_tool = False
+ for tool_name in main_tool_register.get_registered_tool_names():
+ if query.startswith(tool_name):
+ use_one_tool = True
+ query = query[len(tool_name):]
+ break
+
+ # Don't modify bot name
+ all_sessions = Bridge().get_bot("chat").sessions
+ user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages
+
+ logger.debug("[tool]: just-go")
+ try:
+ if use_one_tool:
+ _func, _ = main_tool_register.get_registered_tool()[tool_name]
+ tool = _func(**self.app_kwargs)
+ _reply = tool.run(query)
+ else:
+ # chatgpt-tool-hub will reply you with many tools
+ _reply = self.app.ask(query, user_session)
+ e_context.action = EventAction.BREAK_PASS
+ all_sessions.session_reply(_reply, e_context["context"]["session_id"])
+ except Exception as e:
+ logger.exception(e)
+ logger.error(str(e))
+
+ e_context["context"].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理"
+ reply.type = ReplyType.ERROR
+ e_context.action = EventAction.BREAK
+ return
+
+ reply.content = _reply
+ e_context["reply"] = reply
+ return
+
+ def _read_json(self) -> dict:
+ default_config = {"tools": [], "kwargs": {}}
+ return super().load_config() or default_config
+
+ def _build_tool_kwargs(self, kwargs: dict):
+ tool_model_name = kwargs.get("model_name")
+ request_timeout = kwargs.get("request_timeout")
+
+ return {
+ # 全局配置相关
+ "log": True, # tool 日志开关
+ "debug": kwargs.get("debug", False), # 输出更多日志
+ "no_default": kwargs.get("no_default", False), # 不要默认的工具,只加载自己导入的工具
+ "think_depth": kwargs.get("think_depth", 2), # 一个问题最多使用多少次工具
+ "proxy": conf().get("proxy", ""), # 科学上网
+ "request_timeout": request_timeout if request_timeout else conf().get("request_timeout", 120),
+ "temperature": kwargs.get("temperature", 0), # llm 温度,建议设置0
+ # LLM配置相关
+ "llm_api_key": conf().get("open_ai_api_key", ""), # 如果llm api用key鉴权,传入这里
+ "llm_api_base_url": conf().get("open_ai_api_base", "https://api.openai.com/v1"), # 支持openai接口的llm服务地址前缀
+ "deployment_id": conf().get("azure_deployment_id", ""), # azure openai会用到
+ # note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置
+ "model_name": tool_model_name if tool_model_name else conf().get("model", const.GPT35),
+ # 工具配置相关
+ # for arxiv tool
+ "arxiv_simple": kwargs.get("arxiv_simple", True), # 返回内容更精简
+ "arxiv_top_k_results": kwargs.get("arxiv_top_k_results", 2), # 只返回前k个搜索结果
+ "arxiv_sort_by": kwargs.get("arxiv_sort_by", "relevance"), # 搜索排序方式 ["relevance","lastUpdatedDate","submittedDate"]
+ "arxiv_sort_order": kwargs.get("arxiv_sort_order", "descending"), # 搜索排序方式 ["ascending", "descending"]
+ "arxiv_output_type": kwargs.get("arxiv_output_type", "text"), # 搜索结果类型 ["text", "pdf", "all"]
+ # for bing-search tool
+ "bing_subscription_key": kwargs.get("bing_subscription_key", ""),
+ "bing_search_url": kwargs.get("bing_search_url", "https://api.bing.microsoft.com/v7.0/search"), # 必应搜索的endpoint地址,无需修改
+ "bing_search_top_k_results": kwargs.get("bing_search_top_k_results", 2), # 只返回前k个搜索结果
+ "bing_search_simple": kwargs.get("bing_search_simple", True), # 返回内容更精简
+ "bing_search_output_type": kwargs.get("bing_search_output_type", "text"), # 搜索结果类型 ["text", "json"]
+ # for email tool
+ "email_nickname_mapping": kwargs.get("email_nickname_mapping", "{}"), # 关于人的代号对应的邮箱地址,可以不输入邮箱地址发送邮件。键为代号值为邮箱地址
+ "email_smtp_host": kwargs.get("email_smtp_host", ""), # 例如 'smtp.qq.com'
+ "email_smtp_port": kwargs.get("email_smtp_port", ""), # 例如 587
+ "email_sender": kwargs.get("email_sender", ""), # 发送者的邮件地址
+ "email_authorization_code": kwargs.get("email_authorization_code", ""), # 发送者验证秘钥(可能不是登录密码)
+ # for google-search tool
+ "google_api_key": kwargs.get("google_api_key", ""),
+ "google_cse_id": kwargs.get("google_cse_id", ""),
+ "google_simple": kwargs.get("google_simple", True), # 返回内容更精简
+ "google_output_type": kwargs.get("google_output_type", "text"), # 搜索结果类型 ["text", "json"]
+ # for finance-news tool
+ "finance_news_filter": kwargs.get("finance_news_filter", False), # 是否开启过滤
+ "finance_news_filter_list": kwargs.get("finance_news_filter_list", []), # 过滤词列表
+ "finance_news_simple": kwargs.get("finance_news_simple", True), # 返回内容更精简
+ "finance_news_repeat_news": kwargs.get("finance_news_repeat_news", False), # 是否过滤不返回。该tool每次返回约50条新闻,可能有重复新闻
+ # for morning-news tool
+ "morning_news_api_key": kwargs.get("morning_news_api_key", ""), # api-key
+ "morning_news_simple": kwargs.get("morning_news_simple", True), # 返回内容更精简
+ "morning_news_output_type": kwargs.get("morning_news_output_type", "text"), # 搜索结果类型 ["text", "image"]
+ # for news-api tool
+ "news_api_key": kwargs.get("news_api_key", ""),
+ # for searxng-search tool
+ "searxng_search_host": kwargs.get("searxng_search_host", ""),
+ "searxng_search_top_k_results": kwargs.get("searxng_search_top_k_results", 2), # 只返回前k个搜索结果
+ "searxng_search_output_type": kwargs.get("searxng_search_output_type", "text"), # 搜索结果类型 ["text", "json"]
+ # for sms tool
+ "sms_nickname_mapping": kwargs.get("sms_nickname_mapping", "{}"), # 关于人的代号对应的手机号,可以不输入手机号发送sms。键为代号值为手机号
+ "sms_username": kwargs.get("sms_username", ""), # smsbao用户名
+ "sms_apikey": kwargs.get("sms_apikey", ""), # smsbao
+ # for stt tool
+ "stt_api_key": kwargs.get("stt_api_key", ""), # azure
+ "stt_api_region": kwargs.get("stt_api_region", ""), # azure
+ "stt_recognition_language": kwargs.get("stt_recognition_language", "zh-CN"), # 识别的语言类型 部分:en-US ja-JP ko-KR yue-CN zh-CN
+ # for tts tool
+ "tts_api_key": kwargs.get("tts_api_key", ""), # azure
+ "tts_api_region": kwargs.get("tts_api_region", ""), # azure
+ "tts_auto_detect": kwargs.get("tts_auto_detect", True), # 是否自动检测语音的语言
+ "tts_speech_id": kwargs.get("tts_speech_id", "zh-CN-XiaozhenNeural"), # 输出语音ID
+ # for summary tool
+ "summary_max_segment_length": kwargs.get("summary_max_segment_length", 2500), # 每2500tokens分段,多段触发总结tool
+ # for terminal tool
+ "terminal_nsfc_filter": kwargs.get("terminal_nsfc_filter", True), # 是否过滤llm输出的危险命令
+ "terminal_return_err_output": kwargs.get("terminal_return_err_output", True), # 是否输出错误信息
+ "terminal_timeout": kwargs.get("terminal_timeout", 20), # 允许命令最长执行时间
+ # for visual tool
+ "caption_api_key": kwargs.get("caption_api_key", ""), # ali dashscope apikey
+ # for browser tool
+ "browser_use_summary": kwargs.get("browser_use_summary", True), # 是否对返回结果使用tool功能
+ # for url-get tool
+ "url_get_use_summary": kwargs.get("url_get_use_summary", True), # 是否对返回结果使用tool功能
+ # for wechat tool
+ "wechat_hot_reload": kwargs.get("wechat_hot_reload", True), # 是否使用热重载的方式发送wechat
+ "wechat_cpt_path": kwargs.get("wechat_cpt_path", os.path.join(get_appdata_dir(), "itchat.pkl")), # wechat 配置文件(`itchat.pkl`)
+ "wechat_send_group": kwargs.get("wechat_send_group", False), # 是否向群组发送消息
+ "wechat_nickname_mapping": kwargs.get("wechat_nickname_mapping", "{}"), # 关于人的代号映射关系。键为代号值为微信名(昵称、备注名均可)
+ # for wikipedia tool
+ "wikipedia_top_k_results": kwargs.get("wikipedia_top_k_results", 2), # 只返回前k个搜索结果
+ # for wolfram-alpha tool
+ "wolfram_alpha_appid": kwargs.get("wolfram_alpha_appid", ""),
+ }
+
+ def _filter_tool_list(self, tool_list: list):
+ valid_list = []
+ for tool in tool_list:
+ if tool in main_tool_register.get_registered_tool_names():
+ valid_list.append(tool)
+ else:
+ logger.warning("[tool] filter invalid tool: " + repr(tool))
+ return valid_list
+
+ def _reset_app(self) -> App:
+ self.tool_config = self._read_json()
+ self.app_kwargs = self._build_tool_kwargs(self.tool_config.get("kwargs", {}))
+
+ app = AppFactory()
+ app.init_env(**self.app_kwargs)
+ # filter not support tool
+ tool_list = self._filter_tool_list(self.tool_config.get("tools", []))
+
+ return app.create_app(tools_list=tool_list, **self.app_kwargs)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..abdab57
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,8 @@
+[tool.black]
+line-length = 176
+target-version = ['py37']
+include = '\.pyi?$'
+extend-exclude = '.+/(dist|.venv|venv|build|lib)/.+'
+
+[tool.isort]
+profile = "black"
\ No newline at end of file
diff --git a/requirements-optional.txt b/requirements-optional.txt
new file mode 100644
index 0000000..bae0e37
--- /dev/null
+++ b/requirements-optional.txt
@@ -0,0 +1,42 @@
+tiktoken>=0.3.2 # openai calculate token
+
+#voice
+pydub>=0.25.1 # need ffmpeg
+SpeechRecognition # google speech to text
+gTTS>=2.3.1 # google text to speech
+pyttsx3>=2.90 # pytsx text to speech
+baidu_aip>=4.16.10 # baidu voice
+azure-cognitiveservices-speech # azure voice
+numpy<=1.24.2
+langid # language detect
+
+#install plugin
+dulwich
+
+# wechatmp && wechatcom
+web.py
+wechatpy
+
+# chatgpt-tool-hub plugin
+chatgpt_tool_hub==0.5.0
+
+# xunfei spark
+websocket-client==1.2.0
+
+# claude bot
+curl_cffi
+
+# tongyi qwen
+broadscope_bailian
+
+# google
+google-generativeai
+
+# linkai
+linkai
+
+# dingtalk
+dingtalk_stream
+
+# zhipuai
+zhipuai>=2.0.1
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..c032e08
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+openai==0.27.8
+HTMLParser>=0.0.2
+PyQRCode>=1.2.1
+qrcode>=7.4.2
+requests>=2.28.2
+chardet>=5.1.0
+Pillow
+pre-commit
+web.py
diff --git a/scripts/shutdown.sh b/scripts/shutdown.sh
new file mode 100755
index 0000000..c2bf6b1
--- /dev/null
+++ b/scripts/shutdown.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+#关闭服务
+cd `dirname $0`/..
+export BASE_DIR=`pwd`
+pid=`ps ax | grep -i app.py | grep "${BASE_DIR}" | grep python3 | grep -v grep | awk '{print $1}'`
+if [ -z "$pid" ] ; then
+ echo "No chatgpt-on-wechat running."
+ exit -1;
+fi
+
+echo "The chatgpt-on-wechat(${pid}) is running..."
+
+kill ${pid}
+
+echo "Send shutdown request to chatgpt-on-wechat(${pid}) OK"
diff --git a/scripts/start.sh b/scripts/start.sh
new file mode 100755
index 0000000..3037eb5
--- /dev/null
+++ b/scripts/start.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+#后台运行Chat_on_webchat执行脚本
+
+cd `dirname $0`/..
+export BASE_DIR=`pwd`
+echo $BASE_DIR
+
+# check the nohup.out log output file
+if [ ! -f "${BASE_DIR}/nohup.out" ]; then
+ touch "${BASE_DIR}/nohup.out"
+echo "create file ${BASE_DIR}/nohup.out"
+fi
+
+nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out"
+
+echo "Chat_on_webchat is starting,you can check the ${BASE_DIR}/nohup.out"
diff --git a/scripts/tout.sh b/scripts/tout.sh
new file mode 100755
index 0000000..ffe6de3
--- /dev/null
+++ b/scripts/tout.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+#打开日志
+
+cd `dirname $0`/..
+export BASE_DIR=`pwd`
+echo $BASE_DIR
+
+# check the nohup.out log output file
+if [ ! -f "${BASE_DIR}/nohup.out" ]; then
+ echo "No file ${BASE_DIR}/nohup.out"
+ exit -1;
+fi
+
+tail -f "${BASE_DIR}/nohup.out"
diff --git a/translate/baidu/baidu_translate.py b/translate/baidu/baidu_translate.py
new file mode 100644
index 0000000..6f99e34
--- /dev/null
+++ b/translate/baidu/baidu_translate.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+
+import random
+from hashlib import md5
+
+import requests
+
+from config import conf
+from translate.translator import Translator
+
+
+class BaiduTranslator(Translator):
+ def __init__(self) -> None:
+ super().__init__()
+ endpoint = "http://api.fanyi.baidu.com"
+ path = "/api/trans/vip/translate"
+ self.url = endpoint + path
+ self.appid = conf().get("baidu_translate_app_id")
+ self.appkey = conf().get("baidu_translate_app_key")
+ if not self.appid or not self.appkey:
+ raise Exception("baidu translate appid or appkey not set")
+
+ # For list of language codes, please refer to `https://api.fanyi.baidu.com/doc/21`, need to convert to ISO 639-1 codes
+ def translate(self, query: str, from_lang: str = "", to_lang: str = "en") -> str:
+ if not from_lang:
+ from_lang = "auto" # baidu suppport auto detect
+ salt = random.randint(32768, 65536)
+ sign = self.make_md5("{}{}{}{}".format(self.appid, query, salt, self.appkey))
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
+ payload = {"appid": self.appid, "q": query, "from": from_lang, "to": to_lang, "salt": salt, "sign": sign}
+
+ retry_cnt = 3
+ while retry_cnt:
+ r = requests.post(self.url, params=payload, headers=headers)
+ result = r.json()
+ errcode = result.get("error_code", "52000")
+ if errcode != "52000":
+ if errcode == "52001" or errcode == "52002":
+ retry_cnt -= 1
+ continue
+ else:
+ raise Exception(result["error_msg"])
+ else:
+ break
+ text = "\n".join([item["dst"] for item in result["trans_result"]])
+ return text
+
+ def make_md5(self, s, encoding="utf-8"):
+ return md5(s.encode(encoding)).hexdigest()
diff --git a/translate/factory.py b/translate/factory.py
new file mode 100644
index 0000000..ba80aa5
--- /dev/null
+++ b/translate/factory.py
@@ -0,0 +1,6 @@
+def create_translator(voice_type):
+ if voice_type == "baidu":
+ from translate.baidu.baidu_translate import BaiduTranslator
+
+ return BaiduTranslator()
+ raise RuntimeError
diff --git a/translate/translator.py b/translate/translator.py
new file mode 100644
index 0000000..b394f4e
--- /dev/null
+++ b/translate/translator.py
@@ -0,0 +1,12 @@
+"""
+Voice service abstract class
+"""
+
+
+class Translator(object):
+ # please use https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes to specify language
+ def translate(self, query: str, from_lang: str = "", to_lang: str = "en") -> str:
+ """
+ Translate text from one language to another
+ """
+ raise NotImplementedError
diff --git a/voice/ali/ali_api.py b/voice/ali/ali_api.py
new file mode 100644
index 0000000..cac0c8c
--- /dev/null
+++ b/voice/ali/ali_api.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+"""
+Author: chazzjimel
+Email: chazzjimel@gmail.com
+wechat:cheung-z-x
+
+Description:
+
+"""
+
+import json
+import time
+import requests
+import datetime
+import hashlib
+import hmac
+import base64
+import urllib.parse
+import uuid
+
+from common.log import logger
+from common.tmp_dir import TmpDir
+
+
+def text_to_speech_aliyun(url, text, appkey, token):
+ """
+ 使用阿里云的文本转语音服务将文本转换为语音。
+
+ 参数:
+ - url (str): 阿里云文本转语音服务的端点URL。
+ - text (str): 要转换为语音的文本。
+ - appkey (str): 您的阿里云appkey。
+ - token (str): 阿里云API的认证令牌。
+
+ 返回值:
+ - str: 成功时输出音频文件的路径,否则为None。
+ """
+ headers = {
+ "Content-Type": "application/json",
+ }
+
+ data = {
+ "text": text,
+ "appkey": appkey,
+ "token": token,
+ "format": "wav"
+ }
+
+ response = requests.post(url, headers=headers, data=json.dumps(data))
+
+ if response.status_code == 200 and response.headers['Content-Type'] == 'audio/mpeg':
+ output_file = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
+
+ with open(output_file, 'wb') as file:
+ file.write(response.content)
+ logger.debug(f"音频文件保存成功,文件名:{output_file}")
+ else:
+ logger.debug("响应状态码: {}".format(response.status_code))
+ logger.debug("响应内容: {}".format(response.text))
+ output_file = None
+
+ return output_file
+
+
+class AliyunTokenGenerator:
+ """
+ 用于生成阿里云服务认证令牌的类。
+
+ 属性:
+ - access_key_id (str): 您的阿里云访问密钥ID。
+ - access_key_secret (str): 您的阿里云访问密钥秘密。
+ """
+
+ def __init__(self, access_key_id, access_key_secret):
+ self.access_key_id = access_key_id
+ self.access_key_secret = access_key_secret
+
+ def sign_request(self, parameters):
+ """
+ 为阿里云服务签名请求。
+
+ 参数:
+ - parameters (dict): 请求的参数字典。
+
+ 返回值:
+ - str: 请求的签名签章。
+ """
+ # 将参数按照字典顺序排序
+ sorted_params = sorted(parameters.items())
+
+ # 构造待签名的查询字符串
+ canonicalized_query_string = ''
+ for (k, v) in sorted_params:
+ canonicalized_query_string += '&' + self.percent_encode(k) + '=' + self.percent_encode(v)
+
+ # 构造用于签名的字符串
+ string_to_sign = 'GET&%2F&' + self.percent_encode(canonicalized_query_string[1:]) # 使用GET方法
+
+ # 使用HMAC算法计算签名
+ h = hmac.new((self.access_key_secret + "&").encode('utf-8'), string_to_sign.encode('utf-8'), hashlib.sha1)
+ signature = base64.encodebytes(h.digest()).strip()
+
+ return signature
+
+ def percent_encode(self, encode_str):
+ """
+ 对字符串进行百分比编码。
+
+ 参数:
+ - encode_str (str): 要编码的字符串。
+
+ 返回值:
+ - str: 编码后的字符串。
+ """
+ encode_str = str(encode_str)
+ res = urllib.parse.quote(encode_str, '')
+ res = res.replace('+', '%20')
+ res = res.replace('*', '%2A')
+ res = res.replace('%7E', '~')
+ return res
+
+ def get_token(self):
+ """
+ 获取阿里云服务的令牌。
+
+ 返回值:
+ - str: 获取到的令牌。
+ """
+ # 设置请求参数
+ params = {
+ 'Format': 'JSON',
+ 'Version': '2019-02-28',
+ 'AccessKeyId': self.access_key_id,
+ 'SignatureMethod': 'HMAC-SHA1',
+ 'Timestamp': datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"),
+ 'SignatureVersion': '1.0',
+ 'SignatureNonce': str(uuid.uuid4()), # 使用uuid生成唯一的随机数
+ 'Action': 'CreateToken',
+ 'RegionId': 'cn-shanghai'
+ }
+
+ # 计算签名
+ signature = self.sign_request(params)
+ params['Signature'] = signature
+
+ # 构造请求URL
+ url = 'http://nls-meta.cn-shanghai.aliyuncs.com/?' + urllib.parse.urlencode(params)
+
+ # 发送请求
+ response = requests.get(url)
+
+ return response.text
diff --git a/voice/ali/ali_voice.py b/voice/ali/ali_voice.py
new file mode 100644
index 0000000..79a9aaa
--- /dev/null
+++ b/voice/ali/ali_voice.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+"""
+Author: chazzjimel
+Email: chazzjimel@gmail.com
+wechat:cheung-z-x
+
+Description:
+ali voice service
+
+"""
+import json
+import os
+import re
+import time
+
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from voice.voice import Voice
+from voice.ali.ali_api import AliyunTokenGenerator
+from voice.ali.ali_api import text_to_speech_aliyun
+from config import conf
+
+
+class AliVoice(Voice):
+ def __init__(self):
+ """
+ 初始化AliVoice类,从配置文件加载必要的配置。
+ """
+ try:
+ curdir = os.path.dirname(__file__)
+ config_path = os.path.join(curdir, "config.json")
+ with open(config_path, "r") as fr:
+ config = json.load(fr)
+ self.token = None
+ self.token_expire_time = 0
+ # 默认复用阿里云千问的 access_key 和 access_secret
+ self.api_url = config.get("api_url")
+ self.app_key = config.get("app_key")
+ self.access_key_id = conf().get("qwen_access_key_id") or config.get("access_key_id")
+ self.access_key_secret = conf().get("qwen_access_key_secret") or config.get("access_key_secret")
+ except Exception as e:
+ logger.warn("AliVoice init failed: %s, ignore " % e)
+
+ def textToVoice(self, text):
+ """
+ 将文本转换为语音文件。
+
+ :param text: 要转换的文本。
+ :return: 返回一个Reply对象,其中包含转换得到的语音文件或错误信息。
+ """
+ # 清除文本中的非中文、非英文和非基本字符
+ text = re.sub(r'[^\u4e00-\u9fa5\u3040-\u30FF\uAC00-\uD7AFa-zA-Z0-9'
+ r'äöüÄÖÜáéíóúÁÉÍÓÚàèìòùÀÈÌÒÙâêîôûÂÊÎÔÛçÇñÑ,。!?,.]', '', text)
+ # 提取有效的token
+ token_id = self.get_valid_token()
+ fileName = text_to_speech_aliyun(self.api_url, text, self.app_key, token_id)
+ if fileName:
+ logger.info("[Ali] textToVoice text={} voice file name={}".format(text, fileName))
+ reply = Reply(ReplyType.VOICE, fileName)
+ else:
+ reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
+ return reply
+
+ def get_valid_token(self):
+ """
+ 获取有效的阿里云token。
+
+ :return: 返回有效的token字符串。
+ """
+ current_time = time.time()
+ if self.token is None or current_time >= self.token_expire_time:
+ get_token = AliyunTokenGenerator(self.access_key_id, self.access_key_secret)
+ token_str = get_token.get_token()
+ token_data = json.loads(token_str)
+ self.token = token_data["Token"]["Id"]
+ # 将过期时间减少一小段时间(例如5分钟),以避免在边界条件下的过期
+ self.token_expire_time = token_data["Token"]["ExpireTime"] - 300
+ logger.debug(f"新获取的阿里云token:{self.token}")
+ else:
+ logger.debug("使用缓存的token")
+ return self.token
diff --git a/voice/ali/config.json.template b/voice/ali/config.json.template
new file mode 100644
index 0000000..6a4aaa9
--- /dev/null
+++ b/voice/ali/config.json.template
@@ -0,0 +1,6 @@
+{
+ "api_url": "https://nls-gateway-cn-shanghai.aliyuncs.com/stream/v1/tts",
+ "app_key": "",
+ "access_key_id": "",
+ "access_key_secret": ""
+}
\ No newline at end of file
diff --git a/voice/audio_convert.py b/voice/audio_convert.py
new file mode 100644
index 0000000..18fe3c2
--- /dev/null
+++ b/voice/audio_convert.py
@@ -0,0 +1,133 @@
+import shutil
+import wave
+
+from common.log import logger
+
+try:
+ import pysilk
+except ImportError:
+ logger.warn("import pysilk failed, wechaty voice message will not be supported.")
+
+from pydub import AudioSegment
+
+sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率
+
+
+def find_closest_sil_supports(sample_rate):
+ """
+ 找到最接近的支持的采样率
+ """
+ if sample_rate in sil_supports:
+ return sample_rate
+ closest = 0
+ mindiff = 9999999
+ for rate in sil_supports:
+ diff = abs(rate - sample_rate)
+ if diff < mindiff:
+ closest = rate
+ mindiff = diff
+ return closest
+
+
+def get_pcm_from_wav(wav_path):
+ """
+ 从 wav 文件中读取 pcm
+
+ :param wav_path: wav 文件路径
+ :returns: pcm 数据
+ """
+ wav = wave.open(wav_path, "rb")
+ return wav.readframes(wav.getnframes())
+
+
+def any_to_mp3(any_path, mp3_path):
+ """
+ 把任意格式转成mp3文件
+ """
+ if any_path.endswith(".mp3"):
+ shutil.copy2(any_path, mp3_path)
+ return
+ if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
+ sil_to_wav(any_path, any_path)
+ any_path = mp3_path
+ audio = AudioSegment.from_file(any_path)
+ audio.export(mp3_path, format="mp3")
+
+
+def any_to_wav(any_path, wav_path):
+ """
+ 把任意格式转成wav文件
+ """
+ if any_path.endswith(".wav"):
+ shutil.copy2(any_path, wav_path)
+ return
+ if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
+ return sil_to_wav(any_path, wav_path)
+ audio = AudioSegment.from_file(any_path)
+ audio.export(wav_path, format="wav")
+
+
+def any_to_sil(any_path, sil_path):
+ """
+ 把任意格式转成sil文件
+ """
+ if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
+ shutil.copy2(any_path, sil_path)
+ return 10000
+ audio = AudioSegment.from_file(any_path)
+ rate = find_closest_sil_supports(audio.frame_rate)
+ # Convert to PCM_s16
+ pcm_s16 = audio.set_sample_width(2)
+ pcm_s16 = pcm_s16.set_frame_rate(rate)
+ wav_data = pcm_s16.raw_data
+ silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate)
+ with open(sil_path, "wb") as f:
+ f.write(silk_data)
+ return audio.duration_seconds * 1000
+
+
+def any_to_amr(any_path, amr_path):
+ """
+ 把任意格式转成amr文件
+ """
+ if any_path.endswith(".amr"):
+ shutil.copy2(any_path, amr_path)
+ return
+ if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
+ raise NotImplementedError("Not support file type: {}".format(any_path))
+ audio = AudioSegment.from_file(any_path)
+ audio = audio.set_frame_rate(8000) # only support 8000
+ audio.export(amr_path, format="amr")
+ return audio.duration_seconds * 1000
+
+
+def sil_to_wav(silk_path, wav_path, rate: int = 24000):
+ """
+ silk 文件转 wav
+ """
+ wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate)
+ with open(wav_path, "wb") as f:
+ f.write(wav_data)
+
+
+def split_audio(file_path, max_segment_length_ms=60000):
+ """
+ 分割音频文件
+ """
+ audio = AudioSegment.from_file(file_path)
+ audio_length_ms = len(audio)
+ if audio_length_ms <= max_segment_length_ms:
+ return audio_length_ms, [file_path]
+ segments = []
+ for start_ms in range(0, audio_length_ms, max_segment_length_ms):
+ end_ms = min(audio_length_ms, start_ms + max_segment_length_ms)
+ segment = audio[start_ms:end_ms]
+ segments.append(segment)
+ file_prefix = file_path[: file_path.rindex(".")]
+ format = file_path[file_path.rindex(".") + 1 :]
+ files = []
+ for i, segment in enumerate(segments):
+ path = f"{file_prefix}_{i+1}" + f".{format}"
+ segment.export(path, format=format)
+ files.append(path)
+ return audio_length_ms, files
diff --git a/voice/azure/azure_voice.py b/voice/azure/azure_voice.py
new file mode 100644
index 0000000..b5884ed
--- /dev/null
+++ b/voice/azure/azure_voice.py
@@ -0,0 +1,95 @@
+"""
+azure voice service
+"""
+import json
+import os
+import time
+
+import azure.cognitiveservices.speech as speechsdk
+from langid import classify
+
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common.tmp_dir import TmpDir
+from config import conf
+from voice.voice import Voice
+
+"""
+Azure voice
+主目录设置文件中需填写azure_voice_api_key和azure_voice_region
+
+查看可用的 voice: https://speech.microsoft.com/portal/voicegallery
+
+"""
+
+
+class AzureVoice(Voice):
+ def __init__(self):
+ try:
+ curdir = os.path.dirname(__file__)
+ config_path = os.path.join(curdir, "config.json")
+ config = None
+ if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件
+ config = {
+ "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", # 识别不出时的默认语音
+ "auto_detect": True, # 是否自动检测语言
+ "speech_synthesis_zh": "zh-CN-XiaozhenNeural",
+ "speech_synthesis_en": "en-US-JacobNeural",
+ "speech_synthesis_ja": "ja-JP-AoiNeural",
+ "speech_synthesis_ko": "ko-KR-SoonBokNeural",
+ "speech_synthesis_de": "de-DE-LouisaNeural",
+ "speech_synthesis_fr": "fr-FR-BrigitteNeural",
+ "speech_synthesis_es": "es-ES-LaiaNeural",
+ "speech_recognition_language": "zh-CN",
+ }
+ with open(config_path, "w") as fw:
+ json.dump(config, fw, indent=4)
+ else:
+ with open(config_path, "r") as fr:
+ config = json.load(fr)
+ self.config = config
+ self.api_key = conf().get("azure_voice_api_key")
+ self.api_region = conf().get("azure_voice_region")
+ self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
+ self.speech_config.speech_synthesis_voice_name = self.config["speech_synthesis_voice_name"]
+ self.speech_config.speech_recognition_language = self.config["speech_recognition_language"]
+ except Exception as e:
+ logger.warn("AzureVoice init failed: %s, ignore " % e)
+
+ def voiceToText(self, voice_file):
+ audio_config = speechsdk.AudioConfig(filename=voice_file)
+ speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
+ result = speech_recognizer.recognize_once()
+ if result.reason == speechsdk.ResultReason.RecognizedSpeech:
+ logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text))
+ reply = Reply(ReplyType.TEXT, result.text)
+ else:
+ cancel_details = result.cancellation_details
+ logger.error("[Azure] voiceToText error, result={}, errordetails={}".format(result, cancel_details.error_details))
+ reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
+ return reply
+
+ def textToVoice(self, text):
+ if self.config.get("auto_detect"):
+ lang = classify(text)[0]
+ key = "speech_synthesis_" + lang
+ if key in self.config:
+ logger.info("[Azure] textToVoice auto detect language={}, voice={}".format(lang, self.config[key]))
+ self.speech_config.speech_synthesis_voice_name = self.config[key]
+ else:
+ self.speech_config.speech_synthesis_voice_name = self.config["speech_synthesis_voice_name"]
+ else:
+ self.speech_config.speech_synthesis_voice_name = self.config["speech_synthesis_voice_name"]
+ # Avoid the same filename under multithreading
+ fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
+ audio_config = speechsdk.AudioConfig(filename=fileName)
+ speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
+ result = speech_synthesizer.speak_text(text)
+ if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
+ logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName))
+ reply = Reply(ReplyType.VOICE, fileName)
+ else:
+ cancel_details = result.cancellation_details
+ logger.error("[Azure] textToVoice error, result={}, errordetails={}".format(result, cancel_details.error_details))
+ reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
+ return reply
diff --git a/voice/azure/config.json.template b/voice/azure/config.json.template
new file mode 100644
index 0000000..8f3f546
--- /dev/null
+++ b/voice/azure/config.json.template
@@ -0,0 +1,12 @@
+{
+ "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
+ "auto_detect": true,
+ "speech_synthesis_zh": "zh-CN-YunxiNeural",
+ "speech_synthesis_en": "en-US-JacobNeural",
+ "speech_synthesis_ja": "ja-JP-AoiNeural",
+ "speech_synthesis_ko": "ko-KR-SoonBokNeural",
+ "speech_synthesis_de": "de-DE-LouisaNeural",
+ "speech_synthesis_fr": "fr-FR-BrigitteNeural",
+ "speech_synthesis_es": "es-ES-LaiaNeural",
+ "speech_recognition_language": "zh-CN"
+}
diff --git a/voice/baidu/README.md b/voice/baidu/README.md
new file mode 100644
index 0000000..d4628a1
--- /dev/null
+++ b/voice/baidu/README.md
@@ -0,0 +1,55 @@
+## 说明
+百度语音识别与合成参数说明
+百度语音依赖,经常会出现问题,可能就是缺少依赖:
+pip install baidu-aip
+pip install pydub
+pip install pysilk
+还有ffmpeg,不同系统安装方式不同
+
+系统中收到的语音文件为mp3格式(wx)或者sil格式(wxy),如果要识别需要转换为pcm格式,转换后的文件为16k采样率,单声道,16bit的pcm文件
+发送时又需要(wx)转换为mp3格式,转换后的文件为16k采样率,单声道,16bit的pcm文件,(wxy)转换为sil格式,还要计算声音长度,发送时需要带上声音长度
+这些事情都在audio_convert.py中封装了,直接调用即可
+
+
+参数说明
+识别参数
+https://ai.baidu.com/ai-doc/SPEECH/Vk38lxily
+合成参数
+https://ai.baidu.com/ai-doc/SPEECH/Gk38y8lzk
+
+## 使用说明
+分两个地方配置
+
+1、对于def voiceToText(self, filename)函数中调用的百度语音识别API,中接口调用asr(参数)这个配置见CHATGPT-ON-WECHAT工程目录下的`config.json`文件和config.py文件。
+参数 可需 描述
+app_id 必填 应用的APPID
+api_key 必填 应用的APIKey
+secret_key 必填 应用的SecretKey
+dev_pid 必填 语言选择,填写语言对应的dev_pid值
+
+2、对于def textToVoice(self, text)函数中调用的百度语音合成API,中接口调用synthesis(参数)在本目录下的`config.json`文件中进行配置。
+参数 可需 描述
+tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节
+lan 必填 固定值zh。语言选择,目前只有中英文混合模式,填写固定值zh
+spd 选填 语速,取值0-15,默认为5中语速
+pit 选填 音调,取值0-15,默认为5中语调
+vol 选填 音量,取值0-15,默认为5中音量(取值为0时为音量最小值,并非为无声)
+per(基础音库) 选填 度小宇=1,度小美=0,度逍遥(基础)=3,度丫丫=4
+per(精品音库) 选填 度逍遥(精品)=5003,度小鹿=5118,度博文=106,度小童=110,度小萌=111,度米朵=103,度小娇=5
+aue 选填 3为mp3格式(默认); 4为pcm-16k;5为pcm-8k;6为wav(内容同pcm-16k); 注意aue=4或者6是语音识别要求的格式,但是音频内容不是语音识别要求的自然人发音,所以识别效果会受影响。
+
+关于per参数的说明,注意您购买的哪个音库,就填写哪个音库的参数,否则会报错。如果您购买的是基础音库,那么per参数只能填写0到4,如果您购买的是精品音库,那么per参数只能填写5003,5118,106,110,111,103,5其他的都会报错。
+### 配置文件
+
+将文件夹中`config.json.template`复制为`config.json`。
+
+``` json
+ {
+ "lang": "zh",
+ "ctp": 1,
+ "spd": 5,
+ "pit": 5,
+ "vol": 5,
+ "per": 0
+ }
+```
\ No newline at end of file
diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py
new file mode 100644
index 0000000..fbf53ce
--- /dev/null
+++ b/voice/baidu/baidu_voice.py
@@ -0,0 +1,94 @@
+"""
+baidu voice service
+"""
+import json
+import os
+import time
+
+from aip import AipSpeech
+
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common.tmp_dir import TmpDir
+from config import conf
+from voice.audio_convert import get_pcm_from_wav
+from voice.voice import Voice
+
+"""
+ 百度的语音识别API.
+ dev_pid:
+ - 1936: 普通话远场
+ - 1536:普通话(支持简单的英文识别)
+ - 1537:普通话(纯中文识别)
+ - 1737:英语
+ - 1637:粤语
+ - 1837:四川话
+ 要使用本模块, 首先到 yuyin.baidu.com 注册一个开发者账号,
+ 之后创建一个新应用, 然后在应用管理的"查看key"中获得 API Key 和 Secret Key
+ 然后在 config.json 中填入这两个值, 以及 app_id, dev_pid
+ """
+
+
+class BaiduVoice(Voice):
+ def __init__(self):
+ try:
+ curdir = os.path.dirname(__file__)
+ config_path = os.path.join(curdir, "config.json")
+ bconf = None
+ if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件
+ bconf = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0}
+ with open(config_path, "w") as fw:
+ json.dump(bconf, fw, indent=4)
+ else:
+ with open(config_path, "r") as fr:
+ bconf = json.load(fr)
+
+ self.app_id = str(conf().get("baidu_app_id"))
+ self.api_key = str(conf().get("baidu_api_key"))
+ self.secret_key = str(conf().get("baidu_secret_key"))
+ self.dev_id = conf().get("baidu_dev_pid")
+ self.lang = bconf["lang"]
+ self.ctp = bconf["ctp"]
+ self.spd = bconf["spd"]
+ self.pit = bconf["pit"]
+ self.vol = bconf["vol"]
+ self.per = bconf["per"]
+
+ self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)
+ except Exception as e:
+ logger.warn("BaiduVoice init failed: %s, ignore " % e)
+
+ def voiceToText(self, voice_file):
+ # 识别本地文件
+ logger.debug("[Baidu] voice file name={}".format(voice_file))
+ pcm = get_pcm_from_wav(voice_file)
+ res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
+ if res["err_no"] == 0:
+ logger.info("百度语音识别到了:{}".format(res["result"]))
+ text = "".join(res["result"])
+ reply = Reply(ReplyType.TEXT, text)
+ else:
+ logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
+ if res["err_msg"] == "request pv too much":
+ logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
+ reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"]))
+ return reply
+
+ def textToVoice(self, text):
+ result = self.client.synthesis(
+ text,
+ self.lang,
+ self.ctp,
+ {"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per},
+ )
+ if not isinstance(result, dict):
+ # Avoid the same filename under multithreading
+ fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
+ with open(fileName, "wb") as f:
+ f.write(result)
+ logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName))
+ reply = Reply(ReplyType.VOICE, fileName)
+ else:
+ logger.error("[Baidu] textToVoice error={}".format(result))
+ reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
+ return reply
diff --git a/voice/baidu/config.json.template b/voice/baidu/config.json.template
new file mode 100644
index 0000000..19e812f
--- /dev/null
+++ b/voice/baidu/config.json.template
@@ -0,0 +1,8 @@
+{
+ "lang": "zh",
+ "ctp": 1,
+ "spd": 5,
+ "pit": 5,
+ "vol": 5,
+ "per": 0
+}
diff --git a/voice/elevent/elevent_voice.py b/voice/elevent/elevent_voice.py
new file mode 100644
index 0000000..15936ab
--- /dev/null
+++ b/voice/elevent/elevent_voice.py
@@ -0,0 +1,33 @@
+import time
+
+from elevenlabs import set_api_key,generate
+
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common.tmp_dir import TmpDir
+from voice.voice import Voice
+from config import conf
+
+XI_API_KEY = conf().get("xi_api_key")
+set_api_key(XI_API_KEY)
+name = conf().get("xi_voice_id")
+
+class ElevenLabsVoice(Voice):
+
+ def __init__(self):
+ pass
+
+ def voiceToText(self, voice_file):
+ pass
+
+ def textToVoice(self, text):
+ audio = generate(
+ text=text,
+ voice=name,
+ model='eleven_multilingual_v1'
+ )
+ fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
+ with open(fileName, "wb") as f:
+ f.write(audio)
+ logger.info("[ElevenLabs] textToVoice text={} voice file name={}".format(text, fileName))
+ return Reply(ReplyType.VOICE, fileName)
\ No newline at end of file
diff --git a/voice/factory.py b/voice/factory.py
new file mode 100644
index 0000000..ed80758
--- /dev/null
+++ b/voice/factory.py
@@ -0,0 +1,45 @@
+"""
+voice factory
+"""
+
+
+def create_voice(voice_type):
+ """
+ create a voice instance
+ :param voice_type: voice type code
+ :return: voice instance
+ """
+ if voice_type == "baidu":
+ from voice.baidu.baidu_voice import BaiduVoice
+
+ return BaiduVoice()
+ elif voice_type == "google":
+ from voice.google.google_voice import GoogleVoice
+
+ return GoogleVoice()
+ elif voice_type == "openai":
+ from voice.openai.openai_voice import OpenaiVoice
+
+ return OpenaiVoice()
+ elif voice_type == "pytts":
+ from voice.pytts.pytts_voice import PyttsVoice
+
+ return PyttsVoice()
+ elif voice_type == "azure":
+ from voice.azure.azure_voice import AzureVoice
+
+ return AzureVoice()
+ elif voice_type == "elevenlabs":
+ from voice.elevent.elevent_voice import ElevenLabsVoice
+
+ return ElevenLabsVoice()
+
+ elif voice_type == "linkai":
+ from voice.linkai.linkai_voice import LinkAIVoice
+
+ return LinkAIVoice()
+ elif voice_type == "ali":
+ from voice.ali.ali_voice import AliVoice
+
+ return AliVoice()
+ raise RuntimeError
diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py
new file mode 100644
index 0000000..6dcadad
--- /dev/null
+++ b/voice/google/google_voice.py
@@ -0,0 +1,47 @@
+"""
+google voice service
+"""
+
+import time
+
+import speech_recognition
+from gtts import gTTS
+
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common.tmp_dir import TmpDir
+from voice.voice import Voice
+
+
+class GoogleVoice(Voice):
+ recognizer = speech_recognition.Recognizer()
+
+ def __init__(self):
+ pass
+
+ def voiceToText(self, voice_file):
+ with speech_recognition.AudioFile(voice_file) as source:
+ audio = self.recognizer.record(source)
+ try:
+ text = self.recognizer.recognize_google(audio, language="zh-CN")
+ logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file))
+ reply = Reply(ReplyType.TEXT, text)
+ except speech_recognition.UnknownValueError:
+ reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
+ except speech_recognition.RequestError as e:
+ reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
+ finally:
+ return reply
+
+ def textToVoice(self, text):
+ try:
+ # Avoid the same filename under multithreading
+ mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
+ tts = gTTS(text=text, lang="zh")
+ tts.save(mp3File)
+ logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File))
+ reply = Reply(ReplyType.VOICE, mp3File)
+ except Exception as e:
+ reply = Reply(ReplyType.ERROR, str(e))
+ finally:
+ return reply
diff --git a/voice/linkai/linkai_voice.py b/voice/linkai/linkai_voice.py
new file mode 100644
index 0000000..074c9fd
--- /dev/null
+++ b/voice/linkai/linkai_voice.py
@@ -0,0 +1,83 @@
+"""
+google voice service
+"""
+import random
+import requests
+from voice import audio_convert
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+from voice.voice import Voice
+from common import const
+import os
+import datetime
+
+class LinkAIVoice(Voice):
+ def __init__(self):
+ pass
+
+ def voiceToText(self, voice_file):
+ logger.debug("[LinkVoice] voice file name={}".format(voice_file))
+ try:
+ url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/transcriptions"
+ headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
+ model = None
+ if not conf().get("text_to_voice") or conf().get("voice_to_text") == "openai":
+ model = const.WHISPER_1
+ if voice_file.endswith(".amr"):
+ try:
+ mp3_file = os.path.splitext(voice_file)[0] + ".mp3"
+ audio_convert.any_to_mp3(voice_file, mp3_file)
+ voice_file = mp3_file
+ except Exception as e:
+ logger.warn(f"[LinkVoice] amr file transfer failed, directly send amr voice file: {format(e)}")
+ file = open(voice_file, "rb")
+ file_body = {
+ "file": file
+ }
+ data = {
+ "model": model
+ }
+ res = requests.post(url, files=file_body, headers=headers, data=data, timeout=(5, 60))
+ if res.status_code == 200:
+ text = res.json().get("text")
+ else:
+ res_json = res.json()
+ logger.error(f"[LinkVoice] voiceToText error, status_code={res.status_code}, msg={res_json.get('message')}")
+ return None
+ reply = Reply(ReplyType.TEXT, text)
+ logger.info(f"[LinkVoice] voiceToText success, text={text}, file name={voice_file}")
+ except Exception as e:
+ logger.error(e)
+ return None
+ return reply
+
+ def textToVoice(self, text):
+ try:
+ url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/speech"
+ headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
+ model = const.TTS_1
+ if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
+ model = conf().get("text_to_voice_model") or const.TTS_1
+ data = {
+ "model": model,
+ "input": text,
+ "voice": conf().get("tts_voice_id"),
+ "app_code": conf().get("linkai_app_code")
+ }
+ res = requests.post(url, headers=headers, json=data, timeout=(5, 120))
+ if res.status_code == 200:
+ tmp_file_name = "tmp/" + datetime.datetime.now().strftime('%Y%m%d%H%M%S') + str(random.randint(0, 1000)) + ".mp3"
+ with open(tmp_file_name, 'wb') as f:
+ f.write(res.content)
+ reply = Reply(ReplyType.VOICE, tmp_file_name)
+ logger.info(f"[LinkVoice] textToVoice success, input={text}, model={model}, voice_id={data.get('voice')}")
+ return reply
+ else:
+ res_json = res.json()
+ logger.error(f"[LinkVoice] textToVoice error, status_code={res.status_code}, msg={res_json.get('message')}")
+ return None
+ except Exception as e:
+ logger.error(e)
+ # reply = Reply(ReplyType.ERROR, "遇到了一点小问题,请稍后再问我吧")
+ return None
diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py
new file mode 100644
index 0000000..767353e
--- /dev/null
+++ b/voice/openai/openai_voice.py
@@ -0,0 +1,57 @@
+"""
+google voice service
+"""
+import json
+
+import openai
+
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from config import conf
+from voice.voice import Voice
+import requests
+from common import const
+import datetime, random
+
+class OpenaiVoice(Voice):
+ def __init__(self):
+ openai.api_key = conf().get("open_ai_api_key")
+
+ def voiceToText(self, voice_file):
+ logger.debug("[Openai] voice file name={}".format(voice_file))
+ try:
+ file = open(voice_file, "rb")
+ result = openai.Audio.transcribe("whisper-1", file)
+ text = result["text"]
+ reply = Reply(ReplyType.TEXT, text)
+ logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
+ except Exception as e:
+ reply = Reply(ReplyType.ERROR, "我暂时还无法听清您的语音,请稍后再试吧~")
+ finally:
+ return reply
+
+
+ def textToVoice(self, text):
+ try:
+ api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1"
+ url = f'{api_base}/audio/speech'
+ headers = {
+ 'Authorization': 'Bearer ' + conf().get("open_ai_api_key"),
+ 'Content-Type': 'application/json'
+ }
+ data = {
+ 'model': conf().get("text_to_voice_model") or const.TTS_1,
+ 'input': text,
+ 'voice': conf().get("tts_voice_id") or "alloy"
+ }
+ response = requests.post(url, headers=headers, json=data)
+ file_name = "tmp/" + datetime.datetime.now().strftime('%Y%m%d%H%M%S') + str(random.randint(0, 1000)) + ".mp3"
+ logger.debug(f"[OPENAI] text_to_Voice file_name={file_name}, input={text}")
+ with open(file_name, 'wb') as f:
+ f.write(response.content)
+ logger.info(f"[OPENAI] text_to_Voice success")
+ reply = Reply(ReplyType.VOICE, file_name)
+ except Exception as e:
+ logger.error(e)
+ reply = Reply(ReplyType.ERROR, "遇到了一点小问题,请稍后再问我吧")
+ return reply
diff --git a/voice/pytts/pytts_voice.py b/voice/pytts/pytts_voice.py
new file mode 100644
index 0000000..bd70086
--- /dev/null
+++ b/voice/pytts/pytts_voice.py
@@ -0,0 +1,64 @@
+"""
+pytts voice service (offline)
+"""
+
+import os
+import sys
+import time
+
+import pyttsx3
+
+from bridge.reply import Reply, ReplyType
+from common.log import logger
+from common.tmp_dir import TmpDir
+from voice.voice import Voice
+
+
+class PyttsVoice(Voice):
+ engine = pyttsx3.init()
+
+ def __init__(self):
+ # 语速
+ self.engine.setProperty("rate", 125)
+ # 音量
+ self.engine.setProperty("volume", 1.0)
+ if sys.platform == "win32":
+ for voice in self.engine.getProperty("voices"):
+ if "Chinese" in voice.name:
+ self.engine.setProperty("voice", voice.id)
+ else:
+ self.engine.setProperty("voice", "zh")
+ # If the problem of espeak is fixed, using runAndWait() and remove this startLoop()
+ # TODO: check if this is work on win32
+ self.engine.startLoop(useDriverLoop=False)
+
+ def textToVoice(self, text):
+ try:
+ # Avoid the same filename under multithreading
+ wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
+ wavFile = TmpDir().path() + wavFileName
+ logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile))
+
+ self.engine.save_to_file(text, wavFile)
+
+ if sys.platform == "win32":
+ self.engine.runAndWait()
+ else:
+ # In ubuntu, runAndWait do not really wait until the file created.
+ # It will return once the task queue is empty, but the task is still running in coroutine.
+ # And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this.
+ # If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file.
+ # self.engine.runAndWait()
+
+ # Before espeak fix this problem, we iterate the generator and control the waiting by ourself.
+ # But this is not the canonical way to use it, for example if the file already exists it also cannot wait.
+ self.engine.iterate()
+ while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()):
+ time.sleep(0.1)
+
+ reply = Reply(ReplyType.VOICE, wavFile)
+
+ except Exception as e:
+ reply = Reply(ReplyType.ERROR, str(e))
+ finally:
+ return reply
diff --git a/voice/voice.py b/voice/voice.py
new file mode 100644
index 0000000..1ca199b
--- /dev/null
+++ b/voice/voice.py
@@ -0,0 +1,17 @@
+"""
+Voice service abstract class
+"""
+
+
+class Voice(object):
+ def voiceToText(self, voice_file):
+ """
+ Send voice to voice service and get text
+ """
+ raise NotImplementedError
+
+ def textToVoice(self, text):
+ """
+ Send text to voice service and get voice
+ """
+ raise NotImplementedError