feat 新增两个环境变量用来控制是否对4消息进行限制

This commit is contained in:
Yanyutin753 2024-05-23 14:49:51 +08:00
parent 308f2aea78
commit 7f45621bce
6 changed files with 62 additions and 49 deletions

15
app.py
View File

@ -1,9 +1,8 @@
import asyncio
import uvicorn
from apscheduler.schedulers.background import BackgroundScheduler
from chat2api import app
from chatgpt.chatLimit import clean_background_task
from chatgpt.chatLimit import clean_dict
log_config = uvicorn.config.LOGGING_CONFIG
default_format = "%(asctime)s | %(levelname)s | %(message)s"
@ -11,11 +10,15 @@ access_format = r'%(asctime)s | %(levelname)s | %(client_addr)s: %(request_line)
log_config["formatters"]["default"]["fmt"] = default_format
log_config["formatters"]["access"]["fmt"] = access_format
scheduler = BackgroundScheduler()
# 在应用启动时启动后台任务
# 用于自动更新限制access_token的字典
# 避免过多access_token没有销毁,而堆积
@app.on_event("startup")
async def startup_event():
clean_task = asyncio.create_task(clean_background_task())
async def app_start():
scheduler.add_job(id='updateLimit_run', func=clean_dict, trigger='cron', hour=3, minute=0)
scheduler.start()
uvicorn.run("chat2api:app", host="0.0.0.0", port=5005)

View File

@ -9,11 +9,12 @@ from fastapi.templating import Jinja2Templates
from starlette.background import BackgroundTask
from chatgpt.ChatService import ChatService
from chatgpt.chatLimit import check_isLimit, handle_request_limit
from chatgpt.chatLimit import handle_request_limit
from chatgpt.reverseProxy import chatgpt_reverse_proxy
from utils.Logger import logger
from utils.authorization import verify_token, token_list
from utils.config import api_prefix
from utils.config import enable_limit, limit_status_code
from utils.retry import async_retry
warnings.filterwarnings("ignore")
@ -32,16 +33,12 @@ app.add_middleware(
async def to_send_conversation(request_data, access_token):
try:
limit_response = await handle_request_limit(request_data, access_token)
if limit_response:
raise HTTPException(status_code=429, detail=limit_response)
chat_service = ChatService(access_token)
await chat_service.set_dynamic_data(request_data)
await chat_service.get_chat_requirements()
return chat_service
except HTTPException as e:
await chat_service.close_client()
await check_isLimit(e, access_token)
raise HTTPException(status_code=e.status_code, detail=e.detail)
except Exception as e:
await chat_service.close_client()
@ -56,6 +53,11 @@ async def send_conversation(request: Request, token=Depends(verify_token)):
except Exception:
raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
if enable_limit:
limit_response = await handle_request_limit(request_data, token)
if limit_response:
raise HTTPException(status_code=int(limit_status_code), detail=limit_response)
chat_service = await async_retry(to_send_conversation, request_data, token)
try:
await chat_service.prepare_send_conversation()

View File

@ -10,6 +10,7 @@ from starlette.concurrency import run_in_threadpool
from api.files import get_image_size, get_file_extension, determine_file_use_case
from api.models import model_proxy
from chatgpt.chatFormat import api_messages_to_chat, stream_response, wss_stream_response, format_not_stream_response
from chatgpt.chatLimit import check_isLimit
from chatgpt.proofofWork import get_config, get_dpl, get_answer_token, get_requirements_token
from chatgpt.wssClient import ac2wss, set_wss
from utils.Client import Client
@ -79,7 +80,8 @@ class ChatService:
self.base_url = self.host_url + "/backend-anon"
await get_dpl(self)
self.s.session.cookies.set("__Secure-next-auth.callback-url", "https%3A%2F%2Fchatgpt.com;", domain=self.host_url.split("://")[1], secure=True)
self.s.session.cookies.set("__Secure-next-auth.callback-url", "https%3A%2F%2Fchatgpt.com;",
domain=self.host_url.split("://")[1], secure=True)
async def get_wss_url(self):
url = f'{self.base_url}/register-websocket'
@ -124,9 +126,11 @@ class ChatService:
if proofofwork_required:
proofofwork_diff = proofofwork.get("difficulty")
if proofofwork_diff <= pow_difficulty:
raise HTTPException(status_code=403, detail=f"Proof of work difficulty too high: {proofofwork_diff}")
raise HTTPException(status_code=403,
detail=f"Proof of work difficulty too high: {proofofwork_diff}")
proofofwork_seed = proofofwork.get("seed")
self.proof_token, solved = await run_in_threadpool(get_answer_token, proofofwork_seed, proofofwork_diff, config)
self.proof_token, solved = await run_in_threadpool(get_answer_token, proofofwork_seed,
proofofwork_diff, config)
if not solved:
raise HTTPException(status_code=403, detail="Failed to solve proof of work")
@ -170,7 +174,6 @@ class ChatService:
if r.status_code == 429:
raise HTTPException(status_code=r.status_code, detail="rate-limit")
raise HTTPException(status_code=r.status_code, detail=detail)
except HTTPException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
except Exception as e:
@ -243,11 +246,14 @@ class ChatService:
raise HTTPException(status_code=e.status_code, detail=str(e))
url = f'{self.base_url}/conversation'
stream = self.data.get("stream", False)
r = await self.s.post_stream(url, headers=self.chat_headers, json=self.chat_request, timeout=10, stream=True)
r = await self.s.post_stream(url, headers=self.chat_headers, json=self.chat_request, timeout=10,
stream=True)
if r.status_code != 200:
rtext = await r.atext()
if "application/json" == r.headers.get("Content-Type", ""):
detail = json.loads(rtext).get("detail", json.loads(rtext))
if r.status_code == 429:
check_isLimit(detail, access_token=self.access_token)
else:
if "cf-please-wait" in rtext:
raise HTTPException(status_code=r.status_code, detail="cf-please-wait")
@ -260,7 +266,9 @@ class ChatService:
if "text/event-stream" in content_type and stream:
return stream_response(self, r.aiter_lines(), self.target_model, self.max_tokens)
elif "text/event-stream" in content_type and not stream:
return await format_not_stream_response(stream_response(self, r.aiter_lines(), self.target_model, self.max_tokens), self.prompt_tokens, self.max_tokens, self.target_model)
return await format_not_stream_response(
stream_response(self, r.aiter_lines(), self.target_model, self.max_tokens), self.prompt_tokens,
self.max_tokens, self.target_model)
elif "application/json" in content_type:
rtext = await r.atext()
resp = json.loads(rtext)
@ -275,7 +283,9 @@ class ChatService:
if stream and isinstance(wss_r, types.AsyncGeneratorType):
return stream_response(self, wss_r, self.target_model, self.max_tokens)
else:
return await format_not_stream_response(stream_response(self, wss_r, self.target_model, self.max_tokens), self.prompt_tokens, self.max_tokens, self.target_model)
return await format_not_stream_response(
stream_response(self, wss_r, self.target_model, self.max_tokens), self.prompt_tokens,
self.max_tokens, self.target_model)
finally:
if not isinstance(wss_r, types.AsyncGeneratorType):
await self.ws.close()

View File

@ -1,8 +1,6 @@
import asyncio
import threading
import time
from datetime import datetime
from utils.Logger import logger
# 开启线程锁
@ -11,56 +9,53 @@ lock = threading.Lock()
limit_access_token = {}
async def check_isLimit(e, access_token):
if e.status_code == 429:
clearTime = time.time() + e.detail.get('clears_in')
clearDate = datetime.fromtimestamp(clearTime)
initial_access_list(access_token, clearTime, clearDate)
def check_isLimit(detail, access_token):
if detail.get('clears_in'):
clearTime = time.time() + detail.get('clears_in')
initial_access_list(access_token, clearTime)
async def initial_access_list(key, clear_time):
def initial_access_list(key, clear_time):
with lock:
limit_access_token[key] = clear_time
logger.info(
f"{key}: Reached 429 limit, will be cleared at {datetime.fromtimestamp(clear_time).replace(second=0, microsecond=0)}")
f"{key[:40]}: Reached 429 limit, will be cleared at {datetime.fromtimestamp(clear_time).replace(second=0, microsecond=0)}")
async def remove_refresh_list(key):
def remove_refresh_list(key):
with lock:
if key in limit_access_token:
limit_access_token.pop(key)
logger.info(f"Removed limit for {key[:40]}.")
else:
logger.info(f"Access token {key[:40]} not found in limit_access_token.")
logger.info(f"Access token: {key[:40]} not found in limit_access_token.")
async def handle_request_limit(request_data, access_token):
if "gpt-4" in request_data.get("model") and access_token in limit_access_token:
is_clear = limit_access_token[access_token] < time.time()
if is_clear:
await remove_refresh_list(access_token)
else:
clear_date = datetime.fromtimestamp(limit_access_token[access_token])
logger.info(f"Request limit exceeded. "
f"You can continue with the default model now, "
f"or try again after {clear_date.replace(second=0, microsecond=0)}")
return f"Request limit exceeded. You can continue with the default model now, or try again after {clear_date.replace(second=0, microsecond=0)}"
return None
try:
if "gpt-4" in request_data.get("model", "gpt-3.5-turbo") and access_token in limit_access_token:
limit_time = limit_access_token[access_token]
is_clear = limit_time < time.time()
if is_clear:
remove_refresh_list(access_token)
else:
clear_date = datetime.fromtimestamp(limit_time).replace(second=0, microsecond=0)
logger.info(f"Request limit exceeded. "
f"You can continue with the default model now, "
f"or try again after {clear_date}")
return f"Request limit exceeded. You can continue with the default model now, or try again after {clear_date}"
return None
except Exception as e:
logger.error(e)
return None
# 清理函数,用于删除长时间未使用的键值对
def clean_dict():
logger.info(f"==========================================")
logger.info("开始执行清理过期的limit_access_token......")
current_time = time.time()
keys_to_remove = [key for key, clear_time in limit_access_token.items() if clear_time < current_time]
for key in keys_to_remove:
with lock:
del limit_access_token[key]
# 后台任务,定期执行清理函数
async def clean_background_task():
while True:
clean_dict()
# 一天执行一次防止limit_access_token过多
await asyncio.sleep(3600 * 24)

View File

@ -6,4 +6,5 @@ python-dotenv
websockets
pillow
pybase64
jinja2
jinja2
APScheduler

View File

@ -29,6 +29,7 @@ retry_times = int(os.getenv('RETRY_TIMES', 3))
enable_gateway = is_true(os.getenv('ENABLE_GATEWAY', True))
conversation_only = is_true(os.getenv('CONVERSATION_ONLY', False))
enable_limit = is_true(os.getenv('ENABLE_LIMIT', False))
limit_status_code = os.getenv('LIMIT_STATUS_CODE', 429)
authorization_list = authorization.split(',') if authorization else []
chatgpt_base_url_list = chatgpt_base_url.split(',') if chatgpt_base_url else []
@ -50,4 +51,5 @@ logger.info("RETRY_TIMES: " + str(retry_times))
logger.info("ENABLE_GATEWAY: " + str(enable_gateway))
logger.info("CONVERSATION_ONLY: " + str(conversation_only))
logger.info("ENABLE_LIMIT: " + str(enable_limit))
logger.info("LIMIT_STATUS_CODE: " + str(limit_status_code))
logger.info("-" * 60)