mirror of
https://github.com/lanqian528/chat2api.git
synced 2026-06-13 21:02:46 +08:00
feat 新增两个环境变量用来控制是否对4消息进行限制
This commit is contained in:
parent
308f2aea78
commit
7f45621bce
15
app.py
15
app.py
@ -1,9 +1,8 @@
|
||||
import asyncio
|
||||
|
||||
import uvicorn
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
|
||||
from chat2api import app
|
||||
from chatgpt.chatLimit import clean_background_task
|
||||
from chatgpt.chatLimit import clean_dict
|
||||
|
||||
log_config = uvicorn.config.LOGGING_CONFIG
|
||||
default_format = "%(asctime)s | %(levelname)s | %(message)s"
|
||||
@ -11,11 +10,15 @@ access_format = r'%(asctime)s | %(levelname)s | %(client_addr)s: %(request_line)
|
||||
log_config["formatters"]["default"]["fmt"] = default_format
|
||||
log_config["formatters"]["access"]["fmt"] = access_format
|
||||
|
||||
scheduler = BackgroundScheduler()
|
||||
|
||||
# 在应用启动时启动后台任务
|
||||
|
||||
# 用于自动更新限制access_token的字典
|
||||
# 避免过多access_token没有销毁,而堆积
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
clean_task = asyncio.create_task(clean_background_task())
|
||||
async def app_start():
|
||||
scheduler.add_job(id='updateLimit_run', func=clean_dict, trigger='cron', hour=3, minute=0)
|
||||
scheduler.start()
|
||||
|
||||
|
||||
uvicorn.run("chat2api:app", host="0.0.0.0", port=5005)
|
||||
|
||||
12
chat2api.py
12
chat2api.py
@ -9,11 +9,12 @@ from fastapi.templating import Jinja2Templates
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
from chatgpt.ChatService import ChatService
|
||||
from chatgpt.chatLimit import check_isLimit, handle_request_limit
|
||||
from chatgpt.chatLimit import handle_request_limit
|
||||
from chatgpt.reverseProxy import chatgpt_reverse_proxy
|
||||
from utils.Logger import logger
|
||||
from utils.authorization import verify_token, token_list
|
||||
from utils.config import api_prefix
|
||||
from utils.config import enable_limit, limit_status_code
|
||||
from utils.retry import async_retry
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
@ -32,16 +33,12 @@ app.add_middleware(
|
||||
|
||||
async def to_send_conversation(request_data, access_token):
|
||||
try:
|
||||
limit_response = await handle_request_limit(request_data, access_token)
|
||||
if limit_response:
|
||||
raise HTTPException(status_code=429, detail=limit_response)
|
||||
chat_service = ChatService(access_token)
|
||||
await chat_service.set_dynamic_data(request_data)
|
||||
await chat_service.get_chat_requirements()
|
||||
return chat_service
|
||||
except HTTPException as e:
|
||||
await chat_service.close_client()
|
||||
await check_isLimit(e, access_token)
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
except Exception as e:
|
||||
await chat_service.close_client()
|
||||
@ -56,6 +53,11 @@ async def send_conversation(request: Request, token=Depends(verify_token)):
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
|
||||
|
||||
if enable_limit:
|
||||
limit_response = await handle_request_limit(request_data, token)
|
||||
if limit_response:
|
||||
raise HTTPException(status_code=int(limit_status_code), detail=limit_response)
|
||||
|
||||
chat_service = await async_retry(to_send_conversation, request_data, token)
|
||||
try:
|
||||
await chat_service.prepare_send_conversation()
|
||||
|
||||
@ -10,6 +10,7 @@ from starlette.concurrency import run_in_threadpool
|
||||
from api.files import get_image_size, get_file_extension, determine_file_use_case
|
||||
from api.models import model_proxy
|
||||
from chatgpt.chatFormat import api_messages_to_chat, stream_response, wss_stream_response, format_not_stream_response
|
||||
from chatgpt.chatLimit import check_isLimit
|
||||
from chatgpt.proofofWork import get_config, get_dpl, get_answer_token, get_requirements_token
|
||||
from chatgpt.wssClient import ac2wss, set_wss
|
||||
from utils.Client import Client
|
||||
@ -79,7 +80,8 @@ class ChatService:
|
||||
self.base_url = self.host_url + "/backend-anon"
|
||||
|
||||
await get_dpl(self)
|
||||
self.s.session.cookies.set("__Secure-next-auth.callback-url", "https%3A%2F%2Fchatgpt.com;", domain=self.host_url.split("://")[1], secure=True)
|
||||
self.s.session.cookies.set("__Secure-next-auth.callback-url", "https%3A%2F%2Fchatgpt.com;",
|
||||
domain=self.host_url.split("://")[1], secure=True)
|
||||
|
||||
async def get_wss_url(self):
|
||||
url = f'{self.base_url}/register-websocket'
|
||||
@ -124,9 +126,11 @@ class ChatService:
|
||||
if proofofwork_required:
|
||||
proofofwork_diff = proofofwork.get("difficulty")
|
||||
if proofofwork_diff <= pow_difficulty:
|
||||
raise HTTPException(status_code=403, detail=f"Proof of work difficulty too high: {proofofwork_diff}")
|
||||
raise HTTPException(status_code=403,
|
||||
detail=f"Proof of work difficulty too high: {proofofwork_diff}")
|
||||
proofofwork_seed = proofofwork.get("seed")
|
||||
self.proof_token, solved = await run_in_threadpool(get_answer_token, proofofwork_seed, proofofwork_diff, config)
|
||||
self.proof_token, solved = await run_in_threadpool(get_answer_token, proofofwork_seed,
|
||||
proofofwork_diff, config)
|
||||
if not solved:
|
||||
raise HTTPException(status_code=403, detail="Failed to solve proof of work")
|
||||
|
||||
@ -170,7 +174,6 @@ class ChatService:
|
||||
if r.status_code == 429:
|
||||
raise HTTPException(status_code=r.status_code, detail="rate-limit")
|
||||
raise HTTPException(status_code=r.status_code, detail=detail)
|
||||
|
||||
except HTTPException as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
except Exception as e:
|
||||
@ -243,11 +246,14 @@ class ChatService:
|
||||
raise HTTPException(status_code=e.status_code, detail=str(e))
|
||||
url = f'{self.base_url}/conversation'
|
||||
stream = self.data.get("stream", False)
|
||||
r = await self.s.post_stream(url, headers=self.chat_headers, json=self.chat_request, timeout=10, stream=True)
|
||||
r = await self.s.post_stream(url, headers=self.chat_headers, json=self.chat_request, timeout=10,
|
||||
stream=True)
|
||||
if r.status_code != 200:
|
||||
rtext = await r.atext()
|
||||
if "application/json" == r.headers.get("Content-Type", ""):
|
||||
detail = json.loads(rtext).get("detail", json.loads(rtext))
|
||||
if r.status_code == 429:
|
||||
check_isLimit(detail, access_token=self.access_token)
|
||||
else:
|
||||
if "cf-please-wait" in rtext:
|
||||
raise HTTPException(status_code=r.status_code, detail="cf-please-wait")
|
||||
@ -260,7 +266,9 @@ class ChatService:
|
||||
if "text/event-stream" in content_type and stream:
|
||||
return stream_response(self, r.aiter_lines(), self.target_model, self.max_tokens)
|
||||
elif "text/event-stream" in content_type and not stream:
|
||||
return await format_not_stream_response(stream_response(self, r.aiter_lines(), self.target_model, self.max_tokens), self.prompt_tokens, self.max_tokens, self.target_model)
|
||||
return await format_not_stream_response(
|
||||
stream_response(self, r.aiter_lines(), self.target_model, self.max_tokens), self.prompt_tokens,
|
||||
self.max_tokens, self.target_model)
|
||||
elif "application/json" in content_type:
|
||||
rtext = await r.atext()
|
||||
resp = json.loads(rtext)
|
||||
@ -275,7 +283,9 @@ class ChatService:
|
||||
if stream and isinstance(wss_r, types.AsyncGeneratorType):
|
||||
return stream_response(self, wss_r, self.target_model, self.max_tokens)
|
||||
else:
|
||||
return await format_not_stream_response(stream_response(self, wss_r, self.target_model, self.max_tokens), self.prompt_tokens, self.max_tokens, self.target_model)
|
||||
return await format_not_stream_response(
|
||||
stream_response(self, wss_r, self.target_model, self.max_tokens), self.prompt_tokens,
|
||||
self.max_tokens, self.target_model)
|
||||
finally:
|
||||
if not isinstance(wss_r, types.AsyncGeneratorType):
|
||||
await self.ws.close()
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from utils.Logger import logger
|
||||
|
||||
# 开启线程锁
|
||||
@ -11,56 +9,53 @@ lock = threading.Lock()
|
||||
limit_access_token = {}
|
||||
|
||||
|
||||
async def check_isLimit(e, access_token):
|
||||
if e.status_code == 429:
|
||||
clearTime = time.time() + e.detail.get('clears_in')
|
||||
clearDate = datetime.fromtimestamp(clearTime)
|
||||
initial_access_list(access_token, clearTime, clearDate)
|
||||
def check_isLimit(detail, access_token):
|
||||
if detail.get('clears_in'):
|
||||
clearTime = time.time() + detail.get('clears_in')
|
||||
initial_access_list(access_token, clearTime)
|
||||
|
||||
|
||||
async def initial_access_list(key, clear_time):
|
||||
def initial_access_list(key, clear_time):
|
||||
with lock:
|
||||
limit_access_token[key] = clear_time
|
||||
logger.info(
|
||||
f"{key}: Reached 429 limit, will be cleared at {datetime.fromtimestamp(clear_time).replace(second=0, microsecond=0)}")
|
||||
f"{key[:40]}: Reached 429 limit, will be cleared at {datetime.fromtimestamp(clear_time).replace(second=0, microsecond=0)}")
|
||||
|
||||
|
||||
async def remove_refresh_list(key):
|
||||
def remove_refresh_list(key):
|
||||
with lock:
|
||||
if key in limit_access_token:
|
||||
limit_access_token.pop(key)
|
||||
logger.info(f"Removed limit for {key[:40]}.")
|
||||
else:
|
||||
logger.info(f"Access token {key[:40]} not found in limit_access_token.")
|
||||
logger.info(f"Access token: {key[:40]} not found in limit_access_token.")
|
||||
|
||||
|
||||
async def handle_request_limit(request_data, access_token):
|
||||
if "gpt-4" in request_data.get("model") and access_token in limit_access_token:
|
||||
is_clear = limit_access_token[access_token] < time.time()
|
||||
if is_clear:
|
||||
await remove_refresh_list(access_token)
|
||||
else:
|
||||
clear_date = datetime.fromtimestamp(limit_access_token[access_token])
|
||||
logger.info(f"Request limit exceeded. "
|
||||
f"You can continue with the default model now, "
|
||||
f"or try again after {clear_date.replace(second=0, microsecond=0)}")
|
||||
return f"Request limit exceeded. You can continue with the default model now, or try again after {clear_date.replace(second=0, microsecond=0)}"
|
||||
return None
|
||||
try:
|
||||
if "gpt-4" in request_data.get("model", "gpt-3.5-turbo") and access_token in limit_access_token:
|
||||
limit_time = limit_access_token[access_token]
|
||||
is_clear = limit_time < time.time()
|
||||
if is_clear:
|
||||
remove_refresh_list(access_token)
|
||||
else:
|
||||
clear_date = datetime.fromtimestamp(limit_time).replace(second=0, microsecond=0)
|
||||
logger.info(f"Request limit exceeded. "
|
||||
f"You can continue with the default model now, "
|
||||
f"or try again after {clear_date}")
|
||||
return f"Request limit exceeded. You can continue with the default model now, or try again after {clear_date}"
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
||||
|
||||
# 清理函数,用于删除长时间未使用的键值对
|
||||
def clean_dict():
|
||||
logger.info(f"==========================================")
|
||||
logger.info("开始执行清理过期的limit_access_token......")
|
||||
current_time = time.time()
|
||||
keys_to_remove = [key for key, clear_time in limit_access_token.items() if clear_time < current_time]
|
||||
for key in keys_to_remove:
|
||||
with lock:
|
||||
del limit_access_token[key]
|
||||
|
||||
|
||||
# 后台任务,定期执行清理函数
|
||||
async def clean_background_task():
|
||||
while True:
|
||||
clean_dict()
|
||||
# 一天执行一次,防止limit_access_token过多
|
||||
await asyncio.sleep(3600 * 24)
|
||||
|
||||
@ -6,4 +6,5 @@ python-dotenv
|
||||
websockets
|
||||
pillow
|
||||
pybase64
|
||||
jinja2
|
||||
jinja2
|
||||
APScheduler
|
||||
|
||||
@ -29,6 +29,7 @@ retry_times = int(os.getenv('RETRY_TIMES', 3))
|
||||
enable_gateway = is_true(os.getenv('ENABLE_GATEWAY', True))
|
||||
conversation_only = is_true(os.getenv('CONVERSATION_ONLY', False))
|
||||
enable_limit = is_true(os.getenv('ENABLE_LIMIT', False))
|
||||
limit_status_code = os.getenv('LIMIT_STATUS_CODE', 429)
|
||||
|
||||
authorization_list = authorization.split(',') if authorization else []
|
||||
chatgpt_base_url_list = chatgpt_base_url.split(',') if chatgpt_base_url else []
|
||||
@ -50,4 +51,5 @@ logger.info("RETRY_TIMES: " + str(retry_times))
|
||||
logger.info("ENABLE_GATEWAY: " + str(enable_gateway))
|
||||
logger.info("CONVERSATION_ONLY: " + str(conversation_only))
|
||||
logger.info("ENABLE_LIMIT: " + str(enable_limit))
|
||||
logger.info("LIMIT_STATUS_CODE: " + str(limit_status_code))
|
||||
logger.info("-" * 60)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user