feat: update
parent 5b5989bc02
commit ddd972be2d

@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VS Code
.vscode/

@ -0,0 +1,20 @@
default_stages: [ pre-commit ]

repos:
  - repo: https://github.com/pycqa/isort
    rev: 6.0.1
    hooks:
      - id: isort
        args: ['--profile', 'black']

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.2
    hooks:
      - id: ruff
        args: [ --fix ]

  - repo: https://github.com/psf/black
    rev: 25.1.0
    hooks:
      - id: black
        args: ['--line-length', '120']

@ -0,0 +1,13 @@
FROM python:3.10-slim

WORKDIR /app

COPY requirements.txt .
COPY app.py .
COPY src/ ./src/

RUN pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple --no-cache-dir -r requirements.txt

EXPOSE 8000

CMD ["python", "app.py"]

@ -0,0 +1,22 @@
import logging

from fastapi import FastAPI

from src.routers import chat, upload

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

app = FastAPI(title="YuanBao API Proxy", version="1.0.0")

app.include_router(chat.router)
app.include_router(upload.router)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)

@ -0,0 +1,59 @@
import base64

import requests
from openai import OpenAI

base_url = "http://localhost:8000/v1/"

hy_source = "web"
hy_user = ""  # replace with your user ID
hy_token = ""  # replace with your token

agent_id = "naQivTmsDa"
chat_id = ""  # optional; a conversation is created automatically if omitted

# upload (optional)
url = base_url + "upload"

file_name = "example.png"
with open(file_name, "rb") as f:
    file_data = base64.b64encode(f.read()).decode("utf-8")
data = {
    "agent_id": agent_id,
    "hy_source": hy_source,
    "hy_user": hy_user,
    "file": {
        "file_name": file_name,
        "file_data": file_data,
        "file_type": "image",  # must be "image" or "doc"
    },
}
headers = {"Authorization": f"Bearer {hy_token}"}
response = requests.post(url, json=data, headers=headers)
if response.status_code == 200:
    print("File uploaded successfully:", response.json())
    multimedia = [response.json()]
else:
    print("File upload failed:", response.status_code, response.text)
    multimedia = []
print(multimedia)

# chat
client = OpenAI(base_url=base_url, api_key=hy_token)

response = client.chat.completions.create(
    model="deepseek-v3",
    messages=[{"role": "user", "content": "What is this?"}],
    stream=True,
    extra_body={
        "hy_source": hy_source,
        "hy_user": hy_user,
        "agent_id": agent_id,
        "chat_id": chat_id,
        "should_remove_conversation": False,
        "multimedia": multimedia,
    },
)

for chunk in response:
    print(chunk.choices[0].delta.content or "")

Binary file not shown (image, 164 KiB).

@ -0,0 +1,195 @@
import logging
import os
import re
import time
from typing import Dict, Optional

import requests
from bs4 import BeautifulSoup

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

APPID = "wx12b75947931a04ec"
HEADERS = {
    "x-token": "",
    "x-instance-id": "1",
    "x-language": "zh-CN",
    "x-requested-with": "XMLHttpRequest",
    "x-operationsystem": "win",
    "x-channel": "10014",
    "x-id": "",
    "x-product": "bot",
    "x-appversion": "1.8.1",
    "x-source": "web",
    "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36 Edg/133.0.0.0 app_lang/zh-CN product_id/TM_Product_App app_instance_id/2 os_version/10.0.19045 app_short_version/1.8.1 package_type/publish_release app/tencent_yuanbao app_full_version/1.8.1.610 app_theme/system app_version/1.8.1 os_name/windows c_district/0",
    "x-a3": "c2ac2b24fe3303043553b2b0300019319312",
}

TIMEOUT = 30


class YuanbaoLogin:
    def __init__(self):
        self.headers = {"User-Agent": HEADERS["user-agent"]}
        self.uuid: Optional[str] = None
        self.wx_code: Optional[str] = None
        self.qrcode_path = "qrcode.jpg"

    def get_qrcode(self) -> bool:
        """Fetch the WeChat login QR code and save it locally.

        Returns:
            bool: True if the QR code was fetched successfully.
        """
        try:
            url = "https://open.weixin.qq.com/connect/qrconnect"
            params = {
                "appid": APPID,
                "scope": "snsapi_login",
                "redirect_uri": "https://yuanbao.tencent.com/desktop-redirect.html?&&bindType=wechat_login",
                "state": "",
                "login_type": "jssdk",
                "self_redirect": "true",
                "styletype": "",
                "sizetype": "",
                "bgcolor": "",
                "rst": "",
                "href": "",
            }
            response = requests.get(url, params=params, headers=self.headers, timeout=TIMEOUT)
            response.raise_for_status()

            soup = BeautifulSoup(response.text, "html.parser")
            qrcodes = soup.find_all("img", class_="js_qrcode_img web_qrcode_img")

            if not qrcodes:
                logger.error("QR code element not found")
                return False

            qrcode_src = qrcodes[0].get("src")
            self.uuid = qrcode_src.split("/")[-1]

            qrcode_url = f"https://open.weixin.qq.com{qrcode_src}"
            qrcode_response = requests.get(qrcode_url, headers=self.headers, timeout=TIMEOUT)
            qrcode_response.raise_for_status()

            with open(self.qrcode_path, "wb") as f:
                f.write(qrcode_response.content)

            logger.info("QR code saved to %s, please scan it to log in", self.qrcode_path)
            return True

        except requests.RequestException as e:
            logger.error("Failed to fetch the QR code: %s", str(e))
            return False
        except Exception as e:
            logger.error("Error while processing the QR code: %s", str(e))
            return False

    def check_scan_status(self) -> bool:
        """Poll the QR code scan status.

        Returns:
            bool: True if the WeChat authorization code was obtained.
        """
        if not self.uuid:
            logger.error("UUID not initialized, fetch the QR code first")
            return False

        url = "https://lp.open.weixin.qq.com/connect/l/qrconnect"
        params = {"uuid": self.uuid, "_": int(time.time() * 1000)}
        pattern = r"window\.wx_errcode=(\d*);window\.wx_code='(.*)';"

        self.wx_code = ""
        try:
            for attempt in range(20):
                response = requests.get(url, params=params, headers=self.headers, timeout=TIMEOUT)
                response.raise_for_status()

                logger.debug("Scan status response: %s", response.text)
                match = re.search(pattern, response.text)
                if match:
                    errcode, self.wx_code = match.groups()
                    if self.wx_code:
                        logger.info("User confirmed the login")
                        return True

                    if errcode == "403":
                        logger.warning("User declined the authorization")
                        return False
                    elif errcode == "402":
                        logger.warning("QR code has expired")
                        return False
                    elif errcode == "408":
                        logger.info("Waiting for the user to scan (%d/20)", attempt + 1)
                    elif errcode == "404":
                        logger.info("Scanned, waiting for confirmation (%d/20)", attempt + 1)
                    params["last"] = errcode

                time.sleep(1)

            logger.warning("Timed out waiting for the scan")
            return False

        except requests.RequestException as e:
            logger.error("Failed to check the scan status: %s", str(e))
            return False
        except Exception as e:
            logger.error("Error while handling the scan status: %s", str(e))
            return False

    def login(self) -> Dict[str, str]:
        """Log in to the YuanBao platform with the WeChat authorization code.

        Returns:
            Dict[str, str]: Cookies returned after a successful login.
        """
        if not self.wx_code:
            logger.error("WeChat authorization code not obtained, cannot log in")
            return {}

        url = "https://yuanbao.tencent.com/api/joint/login"
        data = {"type": "wx", "jsCode": self.wx_code, "appid": APPID}

        try:
            response = requests.post(url, json=data, headers=HEADERS, timeout=TIMEOUT)
            response.raise_for_status()
            cookies = response.cookies.get_dict()

            if cookies:
                logger.info("Login succeeded")
                return cookies
            else:
                logger.warning("No cookies in the login response")
                return {}

        except requests.RequestException as e:
            logger.error("Login request failed: %s", str(e))
            return {}
        except Exception as e:
            logger.error("Error during login: %s", str(e))
            return {}


if __name__ == "__main__":
    yuanbao_login = YuanbaoLogin()

    if yuanbao_login.get_qrcode():
        if yuanbao_login.check_scan_status():
            login_cookies = yuanbao_login.login()
            if login_cookies:
                print("Login succeeded, cookies obtained:")
                print(login_cookies)
            else:
                print("Login failed, no cookies obtained")
        else:
            print("WeChat scan failed or timed out")
    else:
        print("Failed to fetch the QR code")

    if os.path.exists(yuanbao_login.qrcode_path):
        try:
            os.remove(yuanbao_login.qrcode_path)
        except Exception as e:
            logger.warning("Failed to clean up the QR code file: %s", str(e))

@ -0,0 +1,159 @@
# YuanBao-Free-API ✨

A service that lets you access Tencent YuanBao through an OpenAI-compatible interface.

## ✨ Core Features

✅ **Fully compatible with the OpenAI API specification**
🚀 **Supports the mainstream YuanBao models** (DeepSeek / HunYuan series)
⚡️ **Streaming output & web search**
🖼️ **Image and document upload**
📦 **Ready-to-use deployment options** (local / Docker)

## ⚠️ Usage Notes

- This project is for **learning and research purposes only**
- Strictly follow the Tencent YuanBao [terms of service](https://yuanbao.tencent.com/)
- `hy_token` is time-limited; once it expires you need to obtain a new one (one way a client can handle this is sketched below)
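
A minimal sketch of how a client might cope with an expired token. It assumes the proxy surfaces upstream failures as non-2xx responses (the exact status codes are not guaranteed by this project) and uses the same request fields as the full example further down:

```python
from openai import OpenAI, APIStatusError

hy_user = "..."   # your user ID (see "Getting the authentication parameters" below)
hy_token = "..."  # your token; it expires periodically

client = OpenAI(base_url="http://localhost:8000/v1/", api_key=hy_token)

try:
    stream = client.chat.completions.create(
        model="deepseek-v3",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        extra_body={"hy_source": "web", "hy_user": hy_user, "agent_id": "naQivTmsDa"},
    )
    for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")
except APIStatusError as e:
    # Most likely cause on a long-running client: hy_token has expired.
    # Re-run `python get_cookies.py` (or capture fresh values in DevTools) and retry.
    print(f"request failed ({e.status_code}); refresh hy_token and retry")
```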

## 🚀 Quick Start

### Environment setup
```bash
git clone https://github.com/chenwr727/yuanbao-free-api.git
cd yuanbao-free-api
pip install -r requirements.txt
```

## 🖥️ Server Deployment

### Run locally
```bash
# service address: http://localhost:8000
python app.py
```

### Docker deployment
```bash
# build the image
docker build -t yuanbao-free-api .

# run the container
docker run -d -p 8000:8000 --name yuanbao-api yuanbao-free-api
```

## 📡 Client Usage

### Getting the authentication parameters
#### Manual capture

1. Open [Tencent YuanBao](https://yuanbao.tencent.com/)
2. Open the browser developer tools (F12)
3. Capture a chat request and read off the following (a sketch of keeping them out of your code follows this list):
   - `hy_user` and `hy_token` from the Cookie
   - `agent_id` from the request body
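
One convenient pattern (not required by the project) is to keep the captured values in environment variables rather than hardcoding them; the variable names below are my own choice, not something the API expects:

```python
import os

# Hypothetical environment variable names; use whatever fits your setup.
hy_user = os.environ["YUANBAO_HY_USER"]
hy_token = os.environ["YUANBAO_HY_TOKEN"]
agent_id = os.environ.get("YUANBAO_AGENT_ID", "naQivTmsDa")
```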

#### Automatic capture
```bash
# prints the authentication parameters after you scan the QR code
python get_cookies.py
```

### API call example
```python
import base64

import requests
from openai import OpenAI

base_url = "http://localhost:8000/v1/"

hy_source = "web"
hy_user = ""  # replace with your user ID
hy_token = ""  # replace with your token

agent_id = "naQivTmsDa"
chat_id = ""  # optional; a conversation is created automatically if omitted

# upload (optional)
url = base_url + "upload"

file_name = "example.png"
with open(file_name, "rb") as f:
    file_data = base64.b64encode(f.read()).decode("utf-8")
data = {
    "agent_id": agent_id,
    "hy_source": hy_source,
    "hy_user": hy_user,
    "file": {
        "file_name": file_name,
        "file_data": file_data,
        "file_type": "image",  # must be "image" or "doc"
    },
}
headers = {"Authorization": f"Bearer {hy_token}"}
response = requests.post(url, json=data, headers=headers)
if response.status_code == 200:
    print("File uploaded successfully:", response.json())
    multimedia = [response.json()]
else:
    print("File upload failed:", response.status_code, response.text)
    multimedia = []
print(multimedia)

# chat
client = OpenAI(base_url=base_url, api_key=hy_token)

response = client.chat.completions.create(
    model="deepseek-v3",
    messages=[{"role": "user", "content": "What is this?"}],
    stream=True,
    extra_body={
        "hy_source": hy_source,
        "hy_user": hy_user,
        "agent_id": agent_id,
        "chat_id": chat_id,
        "should_remove_conversation": False,
        "multimedia": multimedia,
    },
)

for chunk in response:
    print(chunk.choices[0].delta.content or "")
```

## 🧠 Supported Models

| Model name         | Description                                |
|--------------------|--------------------------------------------|
| deepseek-v3        | DeepSeek V3 base model                     |
| deepseek-r1        | DeepSeek R1 enhanced model                 |
| deepseek-v3-search | DeepSeek V3 model (with web search)        |
| deepseek-r1-search | DeepSeek R1 model (with web search)        |
| hunyuan            | Tencent HunYuan base model                 |
| hunyuan-t1         | Tencent HunYuan T1 model                   |
| hunyuan-search     | Tencent HunYuan model (with web search)    |
| hunyuan-t1-search  | Tencent HunYuan T1 model (with web search) |
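
Each name in the table maps directly onto the `model` field of the chat request, so switching models is just a matter of changing that string. A minimal sketch (placeholders stand in for the credentials captured above) that calls a search-enabled model:

```python
from openai import OpenAI

hy_user = "..."   # as captured above
hy_token = "..."  # as captured above

client = OpenAI(base_url="http://localhost:8000/v1/", api_key=hy_token)

stream = client.chat.completions.create(
    model="deepseek-r1-search",  # any model name from the table above
    messages=[{"role": "user", "content": "Summarize this week's AI news."}],
    stream=True,
    extra_body={"hy_source": "web", "hy_user": hy_user, "agent_id": "naQivTmsDa"},
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```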

## 🌟 Use Cases

[FinVizAI](https://github.com/chenwr727/FinVizAI) implements a multi-step financial analysis workflow:
- real-time news search and analysis
- market trend data integration
- structured report generation

[CodexReel](https://github.com/chenwr727/CodexReel) is an AI-based video generation platform:
- accepts an article link or a topic as input (with optional web search)
- handles content understanding and script generation automatically
- delivers asset matching, speech synthesis, and video editing as one integrated output

## 📜 License

MIT License © 2025

## 🤝 Contributing

Contributions are welcome:
1. Open an issue to report a problem
2. Submit a pull request with code
3. Share your integration use cases

@ -0,0 +1,7 @@
beautifulsoup4
fastapi
httpx
pydantic
Requests
sse_starlette
uvicorn

@ -0,0 +1,20 @@
from enum import Enum


class CHUNK_TYPE(str, Enum):
    STATUS = "status"
    SEARCH_WITH_TEXT = "search_with_text"
    REASONER = "reasoner"
    TEXT = "text"


MODEL_MAPPING = {
    "deepseek-v3": {"model": "deep_seek_v3", "support_functions": None},
    "deepseek-r1": {"model": "deep_seek", "support_functions": None},
    "deepseek-v3-search": {"model": "deep_seek_v3", "support_functions": ["supportInternetSearch"]},
    "deepseek-r1-search": {"model": "deep_seek", "support_functions": ["supportInternetSearch"]},
    "hunyuan": {"model": "hunyuan_gpt_175B_0404", "support_functions": None},
    "hunyuan-t1": {"model": "hunyuan_t1", "support_functions": None},
    "hunyuan-search": {"model": "hunyuan_gpt_175B_0404", "support_functions": ["supportInternetSearch"]},
    "hunyuan-t1-search": {"model": "hunyuan_t1", "support_functions": ["supportInternetSearch"]},
}

@ -0,0 +1,28 @@
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from src.utils.common import generate_headers

bearer_scheme = HTTPBearer(auto_error=False)


async def get_authorized_headers(
    request: Request,
    authorization: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
    if not authorization or not authorization.credentials:
        raise HTTPException(status_code=401, detail="need token")

    token = authorization.credentials

    content_type = request.headers.get("Content-Type", "")

    if "application/json" in content_type:
        data = await request.json()
    elif "multipart/form-data" in content_type:
        form = await request.form()
        data = dict(form)
    else:
        raise HTTPException(status_code=400, detail="Unsupported Content-Type")

    headers = generate_headers(data, token)
    return headers

@ -0,0 +1,44 @@
import logging

from fastapi import APIRouter, Depends, HTTPException
from sse_starlette.sse import EventSourceResponse

from src.dependencies.auth import get_authorized_headers
from src.schemas.chat import ChatCompletionRequest, YuanBaoChatCompletionRequest
from src.services.chat.completion import create_completion_stream
from src.services.chat.conversation import create_conversation
from src.utils.chat import get_model_info, parse_messages

router = APIRouter()


@router.post("/v1/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    headers: dict = Depends(get_authorized_headers),
):
    try:
        if not request.chat_id:
            request.chat_id = await create_conversation(request.agent_id, headers)
            logging.info(f"Conversation created with chat_id: {request.chat_id}")

        prompt = parse_messages(request.messages)
        model_info = get_model_info(request.model)
        if not model_info:
            raise HTTPException(status_code=400, detail="invalid model")

        chat_request = YuanBaoChatCompletionRequest(
            agent_id=request.agent_id,
            chat_id=request.chat_id,
            prompt=prompt,
            chat_model_id=model_info["model"],
            multimedia=request.multimedia,
            support_functions=model_info.get("support_functions"),
        )

        generator = create_completion_stream(chat_request, headers, request.should_remove_conversation)
        logging.info(f"Streaming chat completion for chat_id: {request.chat_id}")
        return EventSourceResponse(generator, media_type="text/event-stream")
    except Exception as e:
        logging.error(f"Error in chat_completions: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@ -0,0 +1,33 @@
import logging
from fastapi import APIRouter, Depends, HTTPException

from src.dependencies.auth import get_authorized_headers
from src.schemas.common import Media
from src.schemas.upload import UploadFileRequest
from src.services.upload.info import get_upload_info
from src.services.upload.uploader import upload_file_to_cos

router = APIRouter()


@router.post("/v1/upload", response_model=Media)
async def upload_file(
    request: UploadFileRequest,
    headers: dict = Depends(get_authorized_headers),
):
    try:
        upload_info = await get_upload_info(request.file.file_name, headers)
        logging.info("Upload info retrieved successfully")
        logging.debug(f"upload_info: {upload_info}")

        file_info = await upload_file_to_cos(
            request.file,
            upload_info,
            headers["User-Agent"],
        )
        logging.info("File uploaded successfully")
        logging.debug(f"File uploaded successfully: {file_info}")
        return file_info
    except Exception as e:
        logging.error(f"Error in upload_file: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@ -0,0 +1,63 @@
from typing import List, Optional

from pydantic import BaseModel, field_validator

from src.const import MODEL_MAPPING
from src.schemas.common import Media


class Message(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    messages: List[Message]
    model: str
    agent_id: str
    chat_id: Optional[str] = None
    hy_source: str = "web"
    hy_user: str
    should_remove_conversation: bool = False
    multimedia: List[Media] = []

    @field_validator("messages")
    def check_messages_not_empty(cls, value):
        if not value:
            raise ValueError("messages cannot be an empty list")
        return value

    @field_validator("model")
    def validate_model(cls, value):
        if value not in MODEL_MAPPING:
            raise ValueError(f"model must be one of {list(MODEL_MAPPING.keys())}")
        return value


class YuanBaoChatCompletionRequest(BaseModel):
    agent_id: str
    chat_id: str
    prompt: str
    chat_model_id: str
    multimedia: List[Media] = []
    support_functions: Optional[List[str]]


class ChoiceDelta(BaseModel):
    role: str = "assistant"
    content: str = ""


class Choice(BaseModel):
    index: int = 0
    delta: ChoiceDelta
    finish_reason: Optional[str] = None


class ChatCompletionChunk(BaseModel):
    id: str = ""
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: list[Choice]

@ -0,0 +1,11 @@
from pydantic import BaseModel


class Media(BaseModel):
    type: str
    docType: str
    url: str
    fileName: str
    size: int
    width: int
    height: int

@ -0,0 +1,14 @@
from pydantic import BaseModel, Field


class File(BaseModel):
    file_name: str
    file_data: str
    file_type: str = Field(..., pattern="^(image|doc)$")


class UploadFileRequest(BaseModel):
    agent_id: str
    hy_source: str = "web"
    hy_user: str
    file: File

@ -0,0 +1,58 @@
from typing import AsyncGenerator, Dict

import httpx

from src.schemas.chat import YuanBaoChatCompletionRequest
from src.services.chat.conversation import remove_conversation
from src.utils.chat import process_response_stream

CHAT_URL = "https://yuanbao.tencent.com/api/chat/{}"

DEFAULT_TIMEOUT = 60


class ChatCompletionError(Exception):
    pass


async def create_completion_stream(
    chat_request: YuanBaoChatCompletionRequest,
    headers: Dict[str, str],
    should_remove_conversation: bool = False,
    timeout: int = DEFAULT_TIMEOUT,
) -> AsyncGenerator[str, None]:
    multimedia = [m.model_dump() for m in chat_request.multimedia]
    body = {
        "model": "gpt_175B_0404",
        "prompt": chat_request.prompt,
        "plugin": "Adaptive",
        "displayPrompt": chat_request.prompt,
        "displayPromptType": 1,
        "options": {"imageIntention": {"needIntentionModel": True, "backendUpdateFlag": 2, "intentionStatus": True}},
        "multimedia": multimedia,
        "agentId": chat_request.agent_id,
        "supportHint": 1,
        "version": "v2",
        "chatModelId": chat_request.chat_model_id,
    }
    if chat_request.support_functions:
        body["supportFunctions"] = chat_request.support_functions

    try:
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                CHAT_URL.format(chat_request.chat_id),
                json=body,
                headers=headers,
                timeout=timeout,
            ) as response:
                async for chunk in process_response_stream(response, chat_request.chat_id):
                    yield chunk

    except Exception as e:
        raise ChatCompletionError(e)

    finally:
        if should_remove_conversation:
            await remove_conversation(chat_request.chat_id, headers)

@ -0,0 +1,55 @@
from typing import Dict

import httpx

CREATE_URL = "https://yuanbao.tencent.com/api/user/agent/conversation/create"
CLEAR_URL = "https://yuanbao.tencent.com/api/user/agent/conversation/v1/clear"

DEFAULT_TIMEOUT = 60


class ConversationCreationError(Exception):
    pass


class ConversationRemoveError(Exception):
    pass


async def create_conversation(agent_id: str, headers: Dict[str, str], timeout: int = DEFAULT_TIMEOUT) -> str:
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(CREATE_URL, json={"agentId": agent_id}, headers=headers, timeout=timeout)

            if response.status_code != 200:
                raise Exception(f"Request failed. Status code: {response.status_code}, Response: {response.text}")

            try:
                json_data = response.json()
            except ValueError:
                raise Exception(f"Failed to parse response as JSON. Response: {response.text}")

            if "id" not in json_data:
                raise Exception(f"Failed to find 'id' in response JSON. Response: {response.text}")

            return json_data["id"]

    except Exception as e:
        raise ConversationCreationError(e)


async def remove_conversation(chat_id: str, headers: Dict[str, str], timeout: int = DEFAULT_TIMEOUT) -> None:
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                CLEAR_URL,
                json={"conversationIds": [chat_id], "uiOptions": {"noToast": True}},
                headers=headers,
                timeout=timeout,
            )

            if response.status_code != 200:
                raise Exception(f"Request failed. Status code: {response.status_code}, Response: {response.text}")

    except Exception as e:
        raise ConversationRemoveError(e)

@ -0,0 +1,27 @@
from typing import Dict

import httpx

UPLOAD_URL = "https://yuanbao.tencent.com/api/resource/genUploadInfo"

DEFAULT_TIMEOUT = 60


class GetUploadInfoError(Exception):
    pass


async def get_upload_info(file_name: str, headers: Dict[str, str], timeout: int = DEFAULT_TIMEOUT) -> dict:
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                UPLOAD_URL,
                json={"fileName": file_name, "docFrom": "localDoc", "docOpenId": ""},
                headers=headers,
                timeout=timeout,
            )
            response.raise_for_status()
            return response.json()

    except Exception as e:
        raise GetUploadInfoError(e)

@ -0,0 +1,45 @@
import base64
from typing import Dict

import httpx

from src.schemas.upload import File
from src.utils.upload import generate_headers, get_file_info

UPLOAD_HOST = "hunyuan-prod-1258344703.cos.accelerate.myqcloud.com"

DEFAULT_TIMEOUT = 60


class UploadFileToCosError(Exception):
    pass


async def upload_file_to_cos(
    file: File,
    upload_info: Dict,
    user_agent: str,
    timeout: int = DEFAULT_TIMEOUT,
) -> Dict:
    try:
        url = f"https://{UPLOAD_HOST}{upload_info['location']}"

        file_data_bytes = base64.b64decode(file.file_data)
        content_length = len(file_data_bytes)
        headers = generate_headers(file.file_type, content_length, UPLOAD_HOST, upload_info, user_agent)

        async with httpx.AsyncClient() as client:
            response = await client.put(url, headers=headers, content=file_data_bytes, timeout=timeout)
            if response.status_code != 200:
                raise Exception(f"Request failed. Status code: {response.status_code}, Response: {response.text}")

        return get_file_info(
            file.file_type,
            file.file_name,
            content_length,
            upload_info["resourceUrl"],
            response.text,
        )

    except Exception as e:
        raise UploadFileToCosError(e)

@ -0,0 +1,65 @@
import json
import time
from typing import AsyncGenerator, Dict, List, Optional

import httpx

from src.const import CHUNK_TYPE, MODEL_MAPPING
from src.schemas.chat import ChatCompletionChunk, Choice, ChoiceDelta, Message


def get_model_info(model_name: str) -> Optional[Dict]:
    return MODEL_MAPPING.get(model_name.lower(), None)


def parse_messages(messages: List[Message]) -> str:
    # If the conversation contains a user message, send only the message contents;
    # otherwise keep the role labels so the structure survives in the prompt.
    has_user_message = any(m.role == "user" for m in messages)
    if has_user_message:
        prompt = "\n".join([f"{m.content}" for m in messages])
    else:
        prompt = "\n".join([f"{m.role}: {m.content}" for m in messages])
    return prompt


async def process_response_stream(response: httpx.Response, model_id: str) -> AsyncGenerator[str, None]:
    def _create_chunk(content: str, finish_reason: Optional[str] = None) -> str:
        choice_delta = ChoiceDelta(content=content)
        choice = Choice(delta=choice_delta, finish_reason=finish_reason)
        chunk = ChatCompletionChunk(created=int(time.time()), model=model_id, choices=[choice])
        return chunk.model_dump_json(exclude_unset=True)

    status = ""
    start_word = "data: "
    finish_reason = "stop"
    async for line in response.aiter_lines():
        if not line or not line.startswith(start_word):
            continue
        data: str = line[len(start_word) :]

        if data == "[DONE]":
            yield _create_chunk("", finish_reason)
            yield "[DONE]"
            break
        elif data in (CHUNK_TYPE.STATUS, CHUNK_TYPE.SEARCH_WITH_TEXT, CHUNK_TYPE.REASONER, CHUNK_TYPE.TEXT):
            status = data
            continue
        elif not data.startswith("{"):
            continue

        chunk_data: Dict = json.loads(data)
        if status == CHUNK_TYPE.TEXT:
            if chunk_data.get("msg"):
                yield _create_chunk(f"[{status}]" + chunk_data["msg"])
            if chunk_data.get("stopReason"):
                finish_reason = chunk_data["stopReason"]
        elif status == CHUNK_TYPE.REASONER:
            yield _create_chunk(f"[{status}]" + chunk_data.get("content", ""))
        elif status == CHUNK_TYPE.SEARCH_WITH_TEXT:
            docs = chunk_data.get("docs", [])
            yield _create_chunk(f"[{status}]" + json.dumps(docs, ensure_ascii=False))
        elif status == CHUNK_TYPE.STATUS:
            yield _create_chunk(f"[{status}]" + chunk_data.get("msg", ""))

@ -0,0 +1,11 @@
from typing import Dict


def generate_headers(request: dict, token: str) -> Dict[str, str]:
    return {
        "Cookie": f"hy_source={request['hy_source']}; hy_user={request['hy_user']}; hy_token={token}",
        "Origin": "https://yuanbao.tencent.com",
        "Referer": f"https://yuanbao.tencent.com/chat/{request['agent_id']}",
        "X-Agentid": request["agent_id"],
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/134.0.0.0 Safari/537.36",
    }

@ -0,0 +1,109 @@
import hmac
import urllib.parse
import xml.etree.ElementTree as ET
from hashlib import sha1
from typing import Dict


def generate_q_signature(
    http_method: str,
    path: str,
    query_params: Dict[str, str],
    headers: Dict[str, str],
    sign_time: str,
    secret_key: str,
) -> str:

    def url_encode(s: str, safe: str = "") -> str:
        return urllib.parse.quote(s, safe=safe)

    def canonicalize_params(params: Dict[str, str]) -> str:
        normalized = {k.lower(): v for k, v in params.items()}
        sorted_items = sorted(normalized.items())
        return "&".join(f"{url_encode(k)}={url_encode(v)}" for k, v in sorted_items)

    encoded_path = url_encode(path.strip(), safe="/")

    canonical_query_string = canonicalize_params(query_params)

    canonical_headers = canonicalize_params(headers)

    format_string = (
        f"{http_method.lower()}\n" f"{encoded_path}\n" f"{canonical_query_string}\n" f"{canonical_headers}\n"
    )
    format_string_hash = sha1(format_string.encode()).hexdigest()

    string_to_sign = f"sha1\n{sign_time}\n{format_string_hash}\n"
    sign_key = hmac.new(secret_key.encode(), sign_time.encode(), sha1).hexdigest()
    signature = hmac.new(sign_key.encode(), string_to_sign.encode(), sha1).hexdigest()

    return signature


def generate_headers(file_type: str, content_length: int, upload_host: str, upload_info: Dict, user_agent: str) -> Dict:
    content_length = str(content_length)

    headers = {
        "Host": upload_host,
        "Content-Length": content_length,
        "Content-Type": "application/octet-stream",
        "Origin": "https://yuanbao.tencent.com",
        "Referer": "https://yuanbao.tencent.com/",
        "User-Agent": user_agent,
        "x-cos-security-token": upload_info["encryptToken"],
    }

    path = upload_info["location"]
    headers_to_sign = {
        "content-length": content_length,
        "host": upload_host,
    }
    query_params = {}
    sign_time = f"{upload_info['startTime']};{upload_info['expiredTime']}"
    secret_key = upload_info["encryptTmpSecretKey"]

    if file_type == "image":
        headers["Content-Type"] = "image/png"
        pic_operations = (
            '{"is_pic_info":1,"rules":[{"fileid":"%s","rule":"imageMogr2/format/jpg"}]}' % upload_info["location"]
        )
        headers["Pic-Operations"] = pic_operations
        headers_to_sign["pic-operations"] = pic_operations

    signature = generate_q_signature("PUT", path, query_params, headers_to_sign, sign_time, secret_key)

    auth_params = {
        "q-sign-algorithm": "sha1",
        "q-ak": upload_info["encryptTmpSecretId"],
        "q-sign-time": sign_time,
        "q-key-time": sign_time,
        "q-header-list": ";".join(headers_to_sign.keys()),
        "q-url-param-list": "",
        "q-signature": signature,
    }

    headers["Authorization"] = "&".join([f"{k}={v}" for k, v in auth_params.items()])
    return headers


def get_file_info(file_type: str, file_name: str, content_length, url: str, xml_data: str) -> Dict:
    file_info = {
        "type": file_type,
        "docType": file_type,
        "url": url,
        "fileName": file_name,
        "size": 0,
        "width": 0,
        "height": 0,
    }

    if file_type == "image":
        root = ET.fromstring(xml_data)
        process_result = root.find("ProcessResults/Object")

        file_info["size"] = int(process_result.find("Size").text)
        file_info["width"] = int(process_result.find("Width").text)
        file_info["height"] = int(process_result.find("Height").text)
    else:
        file_info["size"] = content_length
    return file_info

@ -11,6 +11,10 @@
#define M_PI 3.14159265358979323846264338327950288
#endif /* M_PI */

#ifndef TWO_PI
#define TWO_PI (2 * 3.14159265358979323846264338327950288)
#endif /* TWO_PI */

#ifndef EPSILON
#define EPSILON 1e-6f
#endif /* EPSILON */

@ -1,5 +1,5 @@
/* Implementation of https://github.com/argonautcode/animal-proc-anim.git */
#include <math.h>
#include <math.h> /**< absf */
#include <stdio.h>
#include <stdlib.h> /**< rand */

@ -70,6 +70,52 @@ void OnButton(swwWindow* o, swwButton btn, int pressed)
void OnScroll(swwWindow* o, float offset)
{}

// Simplify the angle to be in the range [0, 2pi)
static float SimplifyAngle(float angle)
{
    while (angle >= TWO_PI) {
        angle -= TWO_PI;
    }

    while (angle < 0) {
        angle += TWO_PI;
    }

    return angle;
}

// i.e. How many radians do you need to turn the angle to match the anchor?
static float RelativeAngleDiff(float angle, float anchor)
{
    // Since angles are represented by values in [0, 2pi), it's helpful to rotate
    // the coordinate space such that PI is at the anchor. That way we don't have
    // to worry about the "seam" between 0 and 2pi.
    angle = SimplifyAngle(angle + M_PI - anchor);
    anchor = M_PI;

    return anchor - angle;
}

// Constrain the vector to be at a certain range of the anchor
static Vector2f ConstrainDistance(Vector2f pos, Vector2f anchor, float constraint)
{
    return Vector2f_Add(anchor, Vector2f_Sub(pos, anchor).setMag(constraint));
}

// Constrain the angle to be within a certain range of the anchor
static float ConstrainAngle(float angle, float anchor, float constraint)
{
    if (absf(RelativeAngleDiff(angle, anchor)) <= constraint) {
        return SimplifyAngle(angle);
    }

    if (RelativeAngleDiff(angle, anchor) > constraint) {
        return SimplifyAngle(anchor - constraint);
    }

    return SimplifyAngle(anchor + constraint);
}

CSTLVector(Point2fVector, Point2f);
CSTLVector(FloatVector, float);

@ -124,42 +170,47 @@ void Chain_Resolve(Chain* o, Point2f pos)
            Vector2f_Heading(Vector2f_Sub(Vector_At(&o->joints, i - 1), Vector_At(&o->joints, i)));
        Vector_Set(&o->angles, i,
                   ConstrainAngle(cur_angle, Vector_At(&o->angles, i - 1), o->angle_constraint));
        Vector_Set(&o->joints, i,
                   Vector2f_Sub(Vector_At(&o->joints, i - 1),
                                Vector2f_FromAngle(Vector_At(&o->angles, i)).setMag(linkSize)));
        Vector_Set(
            &o->joints, i,
            Vector2f_Sub(Vector_At(&o->joints, i - 1),
                         Vector2f_FromAngle(Vector_At(&o->angles, i)).setMag(o->link_size)));
    }
}

/*
void Chain_FabrikResolve(Chain *o, PVector pos, PVector anchor) {
    // Forward pass
    joints.set(0, pos);
    for (int i = 1; i < joints.size(); i++) {
        joints.set(i, constrainDistance(joints.get(i), joints.get(i-1), linkSize));
    }
void Chain_FabrikResolve(Chain* o, PVector pos, PVector anchor)
{
    // Forward pass
    Vector_Set(&o->joints, 0, pos);
    for (int i = 1; i < Vector_Size(&o->joints); ++i) {
        Vector_Set(&o->joints, i,
                   ConstrainDistance(Vector_At(&o->joints, i),
                                     Vector_At(&o->joints, i - 1, o->link_size)));
    }

    // Backward pass
    joints.set(joints.size() - 1, anchor);
    for (int i = joints.size() - 2; i >= 0; i--) {
        joints.set(i, constrainDistance(joints.get(i), joints.get(i+1), linkSize));
    }
    // Backward pass
    Vector_Set(&o->joints, Vector_Size(&o->joints) - 1, anchor);
    for (int i = Vector_Size(&o->joints) - 2; i >= 0; --i) {
        Vector_Set(&o->joints, i,
                   ConstrainDistance(Vector_At(&o->joints, i), Vector_At(&o->joints, i + 1),
                                     o->link_size));
    }
}

void display() {
    strokeWeight(8);
    stroke(255);
    for (int i = 0; i < joints.size() - 1; i++) {
        PVector startJoint = joints.get(i);
        PVector endJoint = joints.get(i + 1);
        line(startJoint.x, startJoint.y, endJoint.x, endJoint.y);
    }
void Display()
{
    strokeWeight(8);
    stroke(255);
    for (int i = 0; i < joints.size() - 1; i++) {
        PVector startJoint = joints.get(i);
        PVector endJoint = joints.get(i + 1);
        line(startJoint.x, startJoint.y, endJoint.x, endJoint.y);
    }

    fill(42, 44, 53);
    for (PVector joint : joints) {
        ellipse(joint.x, joint.y, 32, 32);
    }
    fill(42, 44, 53);
    for (PVector joint : joints) {
        ellipse(joint.x, joint.y, 32, 32);
    }
}
*/

/* Chain */
void Scene1(swwWindow* o)
