mirror of
https://github.com/lanqian528/chat2api.git
synced 2026-06-12 21:00:46 +08:00
v1.4.0 renew 'retry' logic
This commit is contained in:
parent
c064b39dec
commit
992ded1ecc
2
.github/workflows/build_docker.yml
vendored
2
.github/workflows/build_docker.yml
vendored
@ -37,7 +37,7 @@ jobs:
|
||||
images: lanqian528/chat2api
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=raw,value=v1.3.7
|
||||
type=raw,value=v1.4.0
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
|
||||
11
chat2api.py
11
chat2api.py
@ -58,17 +58,20 @@ async def to_send_conversation(request_data, req_token):
|
||||
raise HTTPException(status_code=500, detail="Server error")
|
||||
|
||||
|
||||
async def process(request_data, req_token):
|
||||
chat_service = await to_send_conversation(request_data, req_token)
|
||||
await chat_service.prepare_send_conversation()
|
||||
res = await chat_service.send_conversation()
|
||||
return chat_service, res
|
||||
|
||||
@app.post(f"/{api_prefix}/v1/chat/completions" if api_prefix else "/v1/chat/completions")
|
||||
async def send_conversation(request: Request, req_token: str = Depends(oauth2_scheme)):
|
||||
try:
|
||||
request_data = await request.json()
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
|
||||
|
||||
chat_service = await async_retry(to_send_conversation, request_data, req_token)
|
||||
chat_service, res = await async_retry(process, request_data, req_token)
|
||||
try:
|
||||
await chat_service.prepare_send_conversation()
|
||||
res = await chat_service.send_conversation()
|
||||
if isinstance(res, types.AsyncGeneratorType):
|
||||
background = BackgroundTask(chat_service.close_client)
|
||||
return StreamingResponse(res, media_type="text/event-stream", background=background)
|
||||
|
||||
@ -1,21 +1,18 @@
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import types
|
||||
import uuid
|
||||
|
||||
import websockets
|
||||
from fastapi import HTTPException
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from api.files import get_image_size, get_file_extension, determine_file_use_case
|
||||
from api.models import model_proxy
|
||||
from chatgpt.authorization import get_req_token, verify_token
|
||||
from chatgpt.chatFormat import api_messages_to_chat, stream_response, wss_stream_response, format_not_stream_response
|
||||
from chatgpt.chatFormat import api_messages_to_chat, stream_response, format_not_stream_response, head_process_response
|
||||
from chatgpt.chatLimit import check_is_limit, handle_request_limit
|
||||
from chatgpt.proofofWork import get_config, get_dpl, get_answer_token, get_requirements_token
|
||||
from chatgpt.turnstile import process_turnstile
|
||||
from chatgpt.wssClient import token2wss, set_wss
|
||||
from utils.Client import Client
|
||||
from utils.Logger import logger
|
||||
from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list, history_disabled, pow_difficulty, \
|
||||
@ -68,12 +65,6 @@ class ChatService:
|
||||
self.arkose_token_url = random.choice(arkose_token_url_list) if arkose_token_url_list else None
|
||||
|
||||
self.s = Client(proxy=self.proxy_url)
|
||||
self.ws = None
|
||||
if conversation_only:
|
||||
self.wss_mode = False
|
||||
self.wss_url = None
|
||||
else:
|
||||
self.wss_mode, self.wss_url = await token2wss(self.req_token)
|
||||
|
||||
self.oai_device_id = str(uuid.uuid4())
|
||||
self.persona = None
|
||||
@ -133,21 +124,6 @@ class ChatService:
|
||||
else:
|
||||
self.req_model = "text-davinci-002-render-sha"
|
||||
|
||||
async def get_wss_url(self):
|
||||
url = f'{self.base_url}/register-websocket'
|
||||
headers = self.base_headers.copy()
|
||||
r = await self.s.post(url, headers=headers, data='', timeout=5)
|
||||
try:
|
||||
if r.status_code == 200:
|
||||
resp = r.json()
|
||||
logger.info(f'register-websocket response:{resp}')
|
||||
wss_url = resp.get('wss_url')
|
||||
return wss_url
|
||||
raise Exception(r.text)
|
||||
except Exception as e:
|
||||
logger.error(f"get_wss_url error: {str(e)}")
|
||||
raise HTTPException(status_code=r.status_code, detail=f"Failed to get wss url: {str(e)}")
|
||||
|
||||
async def get_chat_requirements(self):
|
||||
if conversation_only:
|
||||
return None
|
||||
@ -308,14 +284,6 @@ class ChatService:
|
||||
|
||||
async def send_conversation(self):
|
||||
try:
|
||||
try:
|
||||
if self.wss_mode:
|
||||
if not self.wss_url:
|
||||
self.wss_url = await self.get_wss_url()
|
||||
self.ws = await websockets.connect(self.wss_url, ping_interval=None, subprotocols=["json.reliable.webpubsub.azure.v1"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to wss: {str(e)}", )
|
||||
raise HTTPException(status_code=502, detail="Failed to connect to wss")
|
||||
url = f'{self.base_url}/conversation'
|
||||
stream = self.data.get("stream", False)
|
||||
r = await self.s.post_stream(url, headers=self.chat_headers, json=self.chat_request, timeout=10,
|
||||
@ -338,40 +306,23 @@ class ChatService:
|
||||
raise HTTPException(status_code=r.status_code, detail=detail)
|
||||
|
||||
content_type = r.headers.get("Content-Type", "")
|
||||
if "text/event-stream" in content_type and stream:
|
||||
await set_wss(self.req_token, False)
|
||||
return stream_response(self, r.aiter_lines(), self.resp_model, self.max_tokens)
|
||||
elif "text/event-stream" in content_type and not stream:
|
||||
await set_wss(self.req_token, False)
|
||||
return await format_not_stream_response(
|
||||
stream_response(self, r.aiter_lines(), self.resp_model, self.max_tokens), self.prompt_tokens,
|
||||
self.max_tokens, self.resp_model)
|
||||
if "text/event-stream" in content_type:
|
||||
res, start = await head_process_response(r.aiter_lines())
|
||||
if not start:
|
||||
raise HTTPException(status_code=403, detail="Our systems have detected unusual activity coming from your system. Please try again later.")
|
||||
if stream:
|
||||
return stream_response(self, res, self.resp_model, self.max_tokens)
|
||||
else:
|
||||
return await format_not_stream_response(
|
||||
stream_response(self, res, self.resp_model, self.max_tokens), self.prompt_tokens,
|
||||
self.max_tokens, self.resp_model)
|
||||
elif "application/json" in content_type:
|
||||
rtext = await r.atext()
|
||||
resp = json.loads(rtext)
|
||||
self.wss_url = resp.get('wss_url')
|
||||
conversation_id = resp.get('conversation_id')
|
||||
await set_wss(self.req_token, True, self.wss_url)
|
||||
logger.info(f"next wss_url: {self.wss_url}")
|
||||
if not self.ws:
|
||||
try:
|
||||
self.ws = await websockets.connect(self.wss_url, ping_interval=None, subprotocols=["json.reliable.webpubsub.azure.v1"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to wss: {str(e)}", )
|
||||
raise HTTPException(status_code=502, detail="Failed to connect to wss")
|
||||
wss_r = wss_stream_response(self.ws, conversation_id)
|
||||
try:
|
||||
if stream and isinstance(wss_r, types.AsyncGeneratorType):
|
||||
return stream_response(self, wss_r, self.resp_model, self.max_tokens)
|
||||
else:
|
||||
return await format_not_stream_response(
|
||||
stream_response(self, wss_r, self.resp_model, self.max_tokens), self.prompt_tokens,
|
||||
self.max_tokens, self.resp_model)
|
||||
finally:
|
||||
if not isinstance(wss_r, types.AsyncGeneratorType):
|
||||
await self.ws.close()
|
||||
raise HTTPException(status_code=r.status_code, detail=resp)
|
||||
else:
|
||||
raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type")
|
||||
rtext = await r.atext()
|
||||
raise HTTPException(status_code=r.status_code, detail=rtext)
|
||||
except HTTPException as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
except Exception as e:
|
||||
|
||||
@ -105,6 +105,25 @@ async def wss_stream_response(websocket, conversation_id):
|
||||
continue
|
||||
|
||||
|
||||
async def head_process_response(response):
|
||||
async for chunk in response:
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk.startswith("data: {"):
|
||||
chunk_old_data = json.loads(chunk[6:])
|
||||
message = chunk_old_data.get("message", {})
|
||||
if not message and "error" in chunk_old_data:
|
||||
return response, False
|
||||
role = message.get('author', {}).get('role')
|
||||
if role == 'user' or role == 'system':
|
||||
continue
|
||||
|
||||
status = message.get("status")
|
||||
if status == "in_progress":
|
||||
return response, True
|
||||
return response, False
|
||||
|
||||
|
||||
|
||||
async def stream_response(service, response, model, max_tokens):
|
||||
chat_id = f"chatcmpl-{''.join(random.choice(string.ascii_letters + string.digits) for _ in range(29))}"
|
||||
system_fingerprint_list = model_system_fingerprint.get(model, None)
|
||||
@ -116,9 +135,27 @@ async def stream_response(service, response, model, max_tokens):
|
||||
last_message_id = None
|
||||
last_content_type = None
|
||||
last_recipient = None
|
||||
start = False
|
||||
start = True
|
||||
end = False
|
||||
|
||||
chunk_new_data = {
|
||||
"id": chat_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": ""},
|
||||
"logprobs": None,
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
if system_fingerprint:
|
||||
chunk_new_data["system_fingerprint"] = system_fingerprint
|
||||
yield f"data: {json.dumps(chunk_new_data)}\n\n"
|
||||
|
||||
async for chunk in response:
|
||||
chunk = chunk.decode("utf-8")
|
||||
# chunk = 'data: {"message": null, "conversation_id": "38b8bfcf-9912-45db-a48e-b62fb585c855", "error": "Our systems have detected unusual activity coming from your system. Please try again later."}'
|
||||
@ -237,21 +274,8 @@ async def stream_response(service, response, model, max_tokens):
|
||||
last_message_id = message_id
|
||||
if not end and not delta.get("content"):
|
||||
delta = {"role": "assistant", "content": ""}
|
||||
chunk_new_data = {
|
||||
"id": chat_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"logprobs": None,
|
||||
"finish_reason": finish_reason
|
||||
}
|
||||
],
|
||||
"system_fingerprint": system_fingerprint
|
||||
}
|
||||
chunk_new_data["choices"][0]["delta"] = delta
|
||||
chunk_new_data["choices"][0]["finish_reason"] = finish_reason
|
||||
if not service.history_disabled:
|
||||
chunk_new_data.update({
|
||||
"message_id": message_id,
|
||||
|
||||
@ -46,7 +46,7 @@ proxy_url_list = proxy_url.split(',') if proxy_url else []
|
||||
user_agents_list = user_agents.split(',') if user_agents else []
|
||||
|
||||
logger.info("-" * 60)
|
||||
logger.info("Chat2Api v1.3.7 | https://github.com/lanqian528/chat2api")
|
||||
logger.info("Chat2Api v1.4.0 | https://github.com/lanqian528/chat2api")
|
||||
logger.info("-" * 60)
|
||||
logger.info("Environment variables:")
|
||||
logger.info("API_PREFIX: " + str(api_prefix))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user