mirror of
https://github.com/lanqian528/chat2api.git
synced 2026-06-13 21:02:46 +08:00
feat support not stream dalle and tools
fix headers
This commit is contained in:
parent
d884a54d50
commit
cf2cff84de
@ -22,13 +22,12 @@
|
||||
> - [x] 配置 BASE_URL
|
||||
> - [x] 重试次数设置
|
||||
> - [x] ArkoseToken
|
||||
> - [x] GPT4.0 画图、工具 (仅流式,测试阶段)
|
||||
> - [x] GPT4.0 画图、工具 (测试中,可能有bug)
|
||||
> - [x] 使用 RefreshToken 代替 AccessToken
|
||||
> - [x] 反向代理 UI (http://127.0.0.1:5005)
|
||||
> - [x] 反向代理 UI (http://127.0.0.1:5005, 不支持登录)
|
||||
>
|
||||
> TODO
|
||||
>
|
||||
> - [ ] GPT4.0 画图、工具(非流式)
|
||||
> - [ ] GPTs
|
||||
|
||||
## Deploy
|
||||
|
||||
@ -6,8 +6,7 @@ import uuid
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.chat_completions import num_tokens_from_messages, model_system_fingerprint, model_proxy, \
|
||||
split_tokens_from_content
|
||||
from api.chat_completions import num_tokens_from_messages, model_system_fingerprint, model_proxy, num_tokens_from_content
|
||||
from chatgpt.proofofwork import calc_proof_token
|
||||
from utils.Client import Client
|
||||
from utils.Logger import Logger
|
||||
@ -27,6 +26,7 @@ async def stream_response(service, response, model, max_tokens):
|
||||
last_recipient = None
|
||||
end = False
|
||||
message_id = None
|
||||
all_text = ""
|
||||
async for chunk in response.aiter_lines():
|
||||
chunk = chunk.decode("utf-8")
|
||||
if end:
|
||||
@ -94,7 +94,10 @@ async def stream_response(service, response, model, max_tokens):
|
||||
Logger.debug(f"asset_pointer: {asset_pointer}")
|
||||
image_download_url = await service.get_image_download_url(asset_pointer)
|
||||
Logger.debug(f"image_download_url: {image_download_url}")
|
||||
delta = {"content": f"\n```\n\n"}
|
||||
if image_download_url:
|
||||
delta = {"content": f"\n```\n\n"}
|
||||
else:
|
||||
delta = {"content": f"\n```\nFailed to load the image.\n"}
|
||||
elif not message.get("end_turn") or not message.get("metadata").get("finish_details"):
|
||||
message_id = None
|
||||
len_last_content = 0
|
||||
@ -105,6 +108,7 @@ async def stream_response(service, response, model, max_tokens):
|
||||
end = True
|
||||
else:
|
||||
continue
|
||||
all_text += delta.get("content", "")
|
||||
chunk_new_data = {
|
||||
"id": chat_id,
|
||||
"object": "chat.completion.chunk",
|
||||
@ -127,23 +131,119 @@ async def stream_response(service, response, model, max_tokens):
|
||||
continue
|
||||
|
||||
|
||||
async def chat_response(resp, model, prompt_tokens, max_tokens):
|
||||
last_resp = None
|
||||
for i in reversed(resp):
|
||||
if i != "data: [DONE]" and i.startswith("data: "):
|
||||
try:
|
||||
last_resp = json.loads(i[6:])
|
||||
if not last_resp.get("message"):
|
||||
raise
|
||||
async def chat_response(service, response, prompt_tokens, model, max_tokens):
|
||||
chat_id = f"chatcmpl-{''.join(random.choice(string.ascii_letters + string.digits) for _ in range(29))}"
|
||||
system_fingerprint_list = model_system_fingerprint.get(model, None)
|
||||
system_fingerprint = random.choice(system_fingerprint_list) if system_fingerprint_list else None
|
||||
created_time = int(time.time())
|
||||
finish_reason = "stop"
|
||||
completion_tokens = -1
|
||||
len_last_content = 0
|
||||
last_content_type = None
|
||||
last_recipient = None
|
||||
end = False
|
||||
message_id = None
|
||||
all_text = ""
|
||||
async for chunk in response.aiter_lines():
|
||||
chunk = chunk.decode("utf-8")
|
||||
if end:
|
||||
break
|
||||
try:
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
except Exception:
|
||||
Logger.error(f"Error: {i}")
|
||||
elif not chunk.startswith("data: "):
|
||||
continue
|
||||
usage, system_fingerprint, finish_reason, message = await init_param(last_resp, max_tokens, model, prompt_tokens)
|
||||
else:
|
||||
chunk_old_data = json.loads(chunk[6:])
|
||||
finish_reason = None
|
||||
message = chunk_old_data.get("message", {})
|
||||
status = message.get("status")
|
||||
content = message.get("content", {})
|
||||
recipient = message.get("recipient", "")
|
||||
if not message and chunk_old_data.get("type") == "moderation":
|
||||
delta = {"role": "assistant", "content": moderation_message}
|
||||
finish_reason = "stop"
|
||||
end = True
|
||||
elif status == "in_progress":
|
||||
outer_content_type = content.get("content_type")
|
||||
if outer_content_type == "text":
|
||||
part = content.get("parts", [])[0]
|
||||
if not part:
|
||||
message_id = message.get("id")
|
||||
new_text = ""
|
||||
else:
|
||||
if message_id != message.get("id"):
|
||||
continue
|
||||
new_text = part[len_last_content:]
|
||||
len_last_content = len(part)
|
||||
else:
|
||||
text = content.get("text", "")
|
||||
if outer_content_type == "code" and last_content_type != "code":
|
||||
new_text = "\n```" + recipient + "\n" + text[len_last_content:]
|
||||
elif outer_content_type == "execution_output" and last_content_type != "execution_output":
|
||||
new_text = "\n```" + "Output" + "\n" + text[len_last_content:]
|
||||
else:
|
||||
new_text = text[len_last_content:]
|
||||
len_last_content = len(text)
|
||||
if last_content_type == "code" and outer_content_type != "code":
|
||||
new_text = "\n```\n" + new_text
|
||||
elif last_content_type == "execution_output" and outer_content_type != "execution_output":
|
||||
new_text = "\n```\n" + new_text
|
||||
if recipient == "dalle.text2im" and last_recipient != "dalle.text2im":
|
||||
new_text = "\n```" + "json" + "\n" + new_text
|
||||
delta = {"content": new_text}
|
||||
last_content_type = outer_content_type
|
||||
last_recipient = recipient
|
||||
if completion_tokens >= max_tokens:
|
||||
delta = {}
|
||||
finish_reason = "length"
|
||||
end = True
|
||||
elif status == "finished_successfully":
|
||||
if content.get("content_type") == "multimodal_text":
|
||||
parts = content.get("parts", [])
|
||||
delta = {}
|
||||
for part in parts:
|
||||
inner_content_type = part.get('content_type')
|
||||
if inner_content_type == "image_asset_pointer":
|
||||
last_content_type = "image_asset_pointer"
|
||||
asset_pointer = part.get('asset_pointer').replace('file-service://', '')
|
||||
Logger.debug(f"asset_pointer: {asset_pointer}")
|
||||
image_download_url = await service.get_image_download_url(asset_pointer)
|
||||
Logger.debug(f"image_download_url: {image_download_url}")
|
||||
if image_download_url:
|
||||
delta = {"content": f"\n```\n\n"}
|
||||
else:
|
||||
delta = {"content": f"\n```\nFailed to load the image.\n"}
|
||||
elif not message.get("end_turn") or not message.get("metadata").get("finish_details"):
|
||||
message_id = None
|
||||
len_last_content = 0
|
||||
continue
|
||||
else:
|
||||
delta = {}
|
||||
finish_reason = "stop"
|
||||
end = True
|
||||
else:
|
||||
continue
|
||||
all_text += delta.get("content", "")
|
||||
completion_tokens += 1
|
||||
except Exception:
|
||||
Logger.error(f"Error: {chunk}")
|
||||
continue
|
||||
|
||||
completion_tokens = num_tokens_from_content(all_text, model)
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"content": all_text,
|
||||
}
|
||||
usage = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens
|
||||
}
|
||||
return {
|
||||
"id": f"chatcmpl-{''.join(random.choice(string.ascii_letters + string.digits) for _ in range(29))}",
|
||||
"id": chat_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
@ -158,29 +258,6 @@ async def chat_response(resp, model, prompt_tokens, max_tokens):
|
||||
}
|
||||
|
||||
|
||||
async def init_param(last_resp, max_tokens, model, prompt_tokens):
|
||||
if last_resp.get("type") == "moderation":
|
||||
message_content = moderation_message
|
||||
completion_tokens = 53
|
||||
finish_reason = "stop"
|
||||
else:
|
||||
message_content = last_resp["message"]["content"]["parts"][0]
|
||||
message_content, completion_tokens, finish_reason = split_tokens_from_content(message_content, max_tokens,
|
||||
model)
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"content": message_content,
|
||||
}
|
||||
usage = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens
|
||||
}
|
||||
system_fingerprint_list = model_system_fingerprint.get(model, None)
|
||||
system_fingerprint = random.choice(system_fingerprint_list) if system_fingerprint_list else None
|
||||
return usage, system_fingerprint, finish_reason, message
|
||||
|
||||
|
||||
def api_messages_to_chat(api_messages):
|
||||
chat_messages = []
|
||||
for api_message in api_messages:
|
||||
@ -248,6 +325,8 @@ class ChatService:
|
||||
turnstile = resp.get('turnstile', {})
|
||||
arkose_required = arkose.get('required')
|
||||
if arkose_required:
|
||||
if not self.arkose_token_url:
|
||||
raise HTTPException(status_code=403, detail="Arkose service required")
|
||||
arkose_dx = arkose.get("dx")
|
||||
arkose_client = Client()
|
||||
try:
|
||||
@ -258,7 +337,7 @@ class ChatService:
|
||||
)
|
||||
self.arkose_token = r2.json()['token']
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=403, detail="Arkose required")
|
||||
raise HTTPException(status_code=403, detail="Failed to get Arkose token")
|
||||
|
||||
proofofwork_required = proofofwork.get('required')
|
||||
if proofofwork_required:
|
||||
@ -341,33 +420,36 @@ class ChatService:
|
||||
async def send_conversation_for_stream(self):
|
||||
url = f'{self.base_url}/conversation'
|
||||
model = model_proxy.get(self.model, self.model)
|
||||
r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
|
||||
if r.status_code == 200:
|
||||
return stream_response(self, r, model, self.max_tokens)
|
||||
else:
|
||||
rtext = await r.atext()
|
||||
if "application/json" == r.headers.get("Content-Type", ""):
|
||||
detail = json.loads(rtext).get("detail", json.loads(rtext))
|
||||
try:
|
||||
r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
|
||||
if r.status_code == 200:
|
||||
return stream_response(self, r, model, self.max_tokens)
|
||||
else:
|
||||
detail = rtext
|
||||
if r.status_code != 200:
|
||||
if r.status_code == 403:
|
||||
raise HTTPException(status_code=r.status_code, detail="cf-please-wait")
|
||||
raise HTTPException(status_code=r.status_code, detail=detail)
|
||||
rtext = await r.atext()
|
||||
if "application/json" == r.headers.get("Content-Type", ""):
|
||||
detail = json.loads(rtext).get("detail", json.loads(rtext))
|
||||
else:
|
||||
detail = rtext
|
||||
if r.status_code != 200:
|
||||
if r.status_code == 403:
|
||||
raise HTTPException(status_code=r.status_code, detail="cf-please-wait")
|
||||
raise HTTPException(status_code=r.status_code, detail=detail)
|
||||
except HTTPException as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=str(e))
|
||||
|
||||
async def send_conversation(self):
|
||||
url = f'{self.base_url}/conversation'
|
||||
model = model_proxy.get(self.model, self.model)
|
||||
try:
|
||||
r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600)
|
||||
r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
|
||||
if r.status_code == 200:
|
||||
rtext = r.text.split("\n")
|
||||
return await chat_response(rtext, model, self.prompt_tokens, self.max_tokens)
|
||||
return await chat_response(self, r, self.prompt_tokens, model, self.max_tokens)
|
||||
else:
|
||||
rtext = await r.atext()
|
||||
if "application/json" == r.headers.get("Content-Type", ""):
|
||||
detail = r.json().get("detail", r.json())
|
||||
detail = json.loads(rtext).get("detail", json.loads(rtext))
|
||||
else:
|
||||
detail = r.content
|
||||
detail = rtext
|
||||
if r.status_code == 403:
|
||||
raise HTTPException(status_code=r.status_code, detail="cf-please-wait")
|
||||
raise HTTPException(status_code=r.status_code, detail=detail)
|
||||
|
||||
@ -6,6 +6,57 @@ from utils.Client import Client
|
||||
from utils.config import chatgpt_base_url_list, proxy_url_list
|
||||
|
||||
|
||||
headers_reject_list = [
|
||||
"x-real-ip",
|
||||
"x-forwarded-for",
|
||||
"x-forwarded-proto",
|
||||
"x-forwarded-port",
|
||||
"x-forwarded-host",
|
||||
"x-forwarded-server",
|
||||
"cf-warp-tag-id",
|
||||
"cf-visitor",
|
||||
"cf-ray",
|
||||
"cf-connecting-ip",
|
||||
"cf-ipcountry",
|
||||
"cdn-loop",
|
||||
"remote-host",
|
||||
"x-frame-options",
|
||||
"x-xss-protection",
|
||||
"x-content-type-options",
|
||||
"content-security-policy",
|
||||
"host",
|
||||
"cookie",
|
||||
"connection",
|
||||
"content-length",
|
||||
"content-encoding",
|
||||
"x-middleware-prefetch",
|
||||
"x-nextjs-data",
|
||||
"purpose",
|
||||
"x-forwarded-uri",
|
||||
"x-forwarded-path",
|
||||
"x-forwarded-method",
|
||||
"x-forwarded-protocol",
|
||||
"x-forwarded-scheme",
|
||||
"cf-request-id",
|
||||
"cf-worker",
|
||||
"cf-access-client-id",
|
||||
"cf-access-client-device-type",
|
||||
"cf-access-client-device-model",
|
||||
"cf-access-client-device-name",
|
||||
"cf-access-client-device-brand",
|
||||
"x-middleware-prefetch",
|
||||
"x-forwarded-for",
|
||||
"x-forwarded-host",
|
||||
"x-forwarded-proto",
|
||||
"x-forwarded-server",
|
||||
"x-real-ip",
|
||||
"x-forwarded-port",
|
||||
"cf-connecting-ip",
|
||||
"cf-ipcountry",
|
||||
"cf-ray",
|
||||
"cf-visitor",
|
||||
]
|
||||
|
||||
async def chatgpt_reverse_proxy(request: Request, path: str):
|
||||
try:
|
||||
base_url = random.choice(chatgpt_base_url_list)
|
||||
@ -23,9 +74,7 @@ async def chatgpt_reverse_proxy(request: Request, path: str):
|
||||
params = dict(request.query_params)
|
||||
headers = {
|
||||
key: value for key, value in request.headers.items()
|
||||
if (key.lower() not in ["host", "origin", "referer", "user-agent", "authorization"]
|
||||
and not
|
||||
(key.lower().startswith("cf") or key.lower().startswith("cdn") or key.lower().startswith("x")))
|
||||
if (key.lower() not in ["host", "origin", "referer", "user-agent", "authorization"] and key.lower() not in headers_reject_list)
|
||||
}
|
||||
cookies = dict(request.cookies)
|
||||
|
||||
|
||||
@ -8,13 +8,13 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||
|
||||
|
||||
async def verify_token(token: str = Depends(oauth2_scheme)):
|
||||
if not authorization_list:
|
||||
if token and token.startswith("eyJhbGciOi"):
|
||||
return token
|
||||
elif token in authorization_list:
|
||||
return token
|
||||
elif token and token.startswith("eyJhbGciOi"):
|
||||
return token
|
||||
elif len(token) == 45:
|
||||
elif token and len(token) == 45:
|
||||
return await rt2ac(token)
|
||||
elif not authorization_list:
|
||||
return token
|
||||
elif token and token in authorization_list:
|
||||
return token
|
||||
else:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user