Skip to content

Commit

Permalink
feat: Add rate limiting functionality with custom handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
unclecode committed Jul 8, 2024
1 parent 4d283ab commit 65ed1ae
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.database import get_total_count, clear_db

import time
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

# load .env file
from dotenv import load_dotenv
load_dotenv()

# Configuration
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
Expand All @@ -30,6 +39,72 @@

app = FastAPI()

# Initialize rate limiter
def rate_limit_key_func(request: Request):
access_token = request.headers.get("access-token")
if access_token == os.environ.get('ACCESS_TOKEN'):
return None
return get_remote_address(request)

limiter = Limiter(key_func=rate_limit_key_func)
app.state.limiter = limiter

# Dictionary to store last request times for each client
last_request_times = {}

def get_rate_limit():
limit = os.environ.get('ACCESS_PER_MIN', "5")
return f"{limit}/minute"

# Custom rate limit exceeded handler
async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
try_after = last_request_times.get(request.client.host, 0) + 10 - time.time()
reset_at = time.time() + try_after
return JSONResponse(
status_code=429,
content={
"detail": "Rate limit exceeded",
"limit": str(exc.limit.limit),
"reset_at": reset_at,
"message": f"You have exceeded the rate limit of {exc.limit.limit}. Please try again after {try_after} seconds."
}
)

app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)


# Middleware for token-based bypass and per-request limit
class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
SPAN = int(os.environ.get('ACCESS_TIME_SPAN', 10))
access_token = request.headers.get("access-token")
if access_token == os.environ.get('ACCESS_TOKEN'):
return await call_next(request)

path = request.url.path
if path in ["/crawl", "/old"]:
client_ip = request.client.host
current_time = time.time()

# Check time since last request
if client_ip in last_request_times:
time_since_last_request = current_time - last_request_times[client_ip]
if time_since_last_request < SPAN:
return JSONResponse(
status_code=429,
content={
"detail": "Too many requests",
"message": "Rate limit exceeded. Please wait 10 seconds between requests.",
"retry_after": max(0, SPAN - time_since_last_request)
}
)

last_request_times[client_ip] = current_time

return await call_next(request)

app.add_middleware(RateLimitMiddleware)

# CORS configuration
origins = ["*"] # Allow all origins
app.add_middleware(
Expand Down Expand Up @@ -73,6 +148,7 @@ def read_root():
return RedirectResponse(url="/mkdocs")

@app.get("/old", response_class=HTMLResponse)
@limiter.limit(get_rate_limit())
async def read_index(request: Request):
partials_dir = os.path.join(__location__, "pages", "partial")
partials = {}
Expand Down Expand Up @@ -107,6 +183,7 @@ def import_strategy(module_name: str, class_name: str, *args, **kwargs):
raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")

@app.post("/crawl")
@limiter.limit(get_rate_limit())
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
global current_requests
Expand Down
Empty file added middlewares.py
Empty file.

0 comments on commit 65ed1ae

Please sign in to comment.