import argparse
import asyncio
import sqlite3
import time
import uuid
from typing import Any, Callable, Dict, Optional

import httpx
import redis

class RateLimitExceeded(Exception):
    """Raised by :class:`RateLimiter` when a route has used up its request budget."""

class RateLimiter:
    """
    Middleware for FastAPI that implements sliding window rate limiting.

    Every request is logged with its arrival time under a per-route key. A
    request is rejected once the number of logged requests for that route in
    the last ``window`` seconds exceeds ``limit``. Rejected requests are
    logged as well, so a client that keeps hammering a route stays blocked
    until it backs off for a full window.

    Supports Redis or SQLite backends. Per-route limits can be supplied by
    overriding ``_get_limits``.
    """

    def __init__(self, backend: str = "redis", backend_config: Optional[Dict[str, Any]] = None, default_limit: int = 10, default_window: int = 60):
        """
        Initializes the RateLimiter.

        Args:
            backend: The backend to use ("redis" or "sqlite"). Defaults to "redis".
            backend_config: Keyword arguments passed to ``redis.Redis`` or
                ``sqlite3.connect``. Defaults to None (no arguments).
            default_limit: The default number of requests allowed per window. Defaults to 10.
            default_window: The default time window in seconds. Defaults to 60.

        Raises:
            ValueError: If ``backend`` is not supported.
        """
        self.backend = backend
        self.backend_config = backend_config or {}
        self.default_limit = default_limit
        self.default_window = default_window
        self.client = self._connect()

    def _connect(self) -> Any:
        """Connect to the chosen backend and return its client object.

        For SQLite, this also creates the request-log table and its index if
        they do not exist yet.

        Raises:
            ValueError: If ``self.backend`` is neither "redis" nor "sqlite".
        """
        if self.backend == "redis":
            return redis.Redis(**self.backend_config)
        if self.backend == "sqlite":
            conn = sqlite3.connect(**self.backend_config)
            # One row per request (sliding log), keyed by an autoincrement id
            # so many rows can share a route_key. A distinct table name avoids
            # colliding with older databases whose ``rate_limits`` table used
            # route_key as its PRIMARY KEY.
            with conn:
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS rate_limit_events (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        route_key TEXT NOT NULL,
                        timestamp REAL NOT NULL
                    )
                    """
                )
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_rate_limit_events_key_ts "
                    "ON rate_limit_events (route_key, timestamp)"
                )
            return conn
        raise ValueError(f"Unsupported backend: {self.backend}")

    async def __call__(self, request: Any, call_next: Callable[[Any], Any]) -> Any:
        """
        Middleware function that checks and enforces rate limits.

        Args:
            request: The incoming request object; only ``request.url`` is read here.
            call_next: The next middleware or route handler. Following the
                FastAPI ``http`` middleware convention, it is called with the
                request.

        Returns:
            The response from the next middleware or route handler.

        Raises:
            RateLimitExceeded: If the rate limit is exceeded.
        """
        route_key = self._route_key(request)
        limit, window = self._get_limits(request)

        if self.backend == "redis":
            count = await self._redis_get_count(route_key, limit, window)
        else:
            count = await self._sqlite_get_count(route_key, limit, window)

        # The count already includes this request, so ``limit`` requests pass
        # and request ``limit + 1`` is the first to be rejected.
        if count > limit:
            raise RateLimitExceeded("Rate limit exceeded")

        return await call_next(request)

    @staticmethod
    def _route_key(request: Any) -> str:
        """Return the key that identifies the route of *request*.

        Uses ``request.url.path`` when the URL object has one (e.g. Starlette's
        ``URL``), so varying the query string cannot dodge the limit. Otherwise
        it falls back to ``str(request.url)``.
        """
        url = request.url
        path = getattr(url, "path", None)
        return path if path else str(url)

    def _get_limits(self, request: Any) -> tuple[int, int]:
        """
        Retrieves the rate limit and window for the given request.

        Args:
            request: The incoming request object.

        Returns:
            A tuple containing the rate limit and window.
        """
        # Implement route-specific limits here if needed.
        # For example, check request.headers or request.url.path
        return self.default_limit, self.default_window

    async def _redis_get_count(self, route_key: str, limit: int, window: int) -> int:
        """
        Record this request in Redis and return the route's count in the window.

        Each route gets its own sorted set ``rate_limits:<route_key>``. Scores
        are arrival times.

        Args:
            route_key: The key to use for identifying the route.
            limit: The rate limit (unused here; the caller compares against it).
            window: The time window in seconds.

        Returns:
            The number of requests, including this one, recorded for the route
            within the last ``window`` seconds.
        """
        now = time.time()
        key = f"rate_limits:{route_key}"
        # Members must be unique; reusing one member would only update its
        # score, and the set would never hold more than one entry.
        member = f"{now}:{uuid.uuid4().hex}"
        with self.client.pipeline() as pipe:
            pipe.zremrangebyscore(key, 0, now - window)
            pipe.zadd(key, {member: now})
            pipe.zcard(key)
            # Let idle routes expire instead of lingering forever.
            pipe.expire(key, max(int(window), 1))
            results = pipe.execute()
        return int(results[2])

    async def _sqlite_get_count(self, route_key: str, limit: int, window: int) -> int:
        """
        Record this request in SQLite and return the route's count in the window.

        Args:
            route_key: The key to use for identifying the route.
            limit: The rate limit (unused here; the caller compares against it).
            window: The time window in seconds.

        Returns:
            The number of requests, including this one, recorded for the route
            within the last ``window`` seconds.
        """
        now = time.time()
        # The connection context manager commits on success and rolls back on error.
        with self.client:
            # Prune only this route's rows: other routes may use longer windows.
            self.client.execute(
                "DELETE FROM rate_limit_events WHERE route_key = ? AND timestamp <= ?",
                (route_key, now - window),
            )
            self.client.execute(
                "INSERT INTO rate_limit_events (route_key, timestamp) VALUES (?, ?)",
                (route_key, now),
            )
            (count,) = self.client.execute(
                "SELECT COUNT(*) FROM rate_limit_events WHERE route_key = ?",
                (route_key,),
            ).fetchone()
        return count

def main():
    """Run a small CLI demo that sends fake requests through a RateLimiter.

    Backend options are parsed from the command line. The demo sends
    ``default_limit + 2`` requests to one route, so the last two should be
    rejected with ``RateLimitExceeded``.
    """
    parser = argparse.ArgumentParser(description="Rate Limiter Example")
    parser.add_argument("--backend", choices=["redis", "sqlite"], default="redis", help="Backend to use (redis or sqlite)")
    parser.add_argument("--redis-host", default="localhost", help="Redis host")
    parser.add_argument("--redis-port", type=int, default=6379, help="Redis port")
    parser.add_argument("--sqlite-db", default="rate_limit.db", help="SQLite database file")

    args = parser.parse_args()

    # argparse `choices` guarantees the backend is one of these two.
    if args.backend == "redis":
        backend_config = {"host": args.redis_host, "port": args.redis_port}
    else:
        backend_config = {"database": args.sqlite_db}

    rate_limiter = RateLimiter(backend=args.backend, backend_config=backend_config)

    class _DemoRequest:
        """Minimal stand-in for a FastAPI request; the limiter reads only ``url``."""

        url = "http://localhost/example"

    # Example usage (replace with your FastAPI app)
    async def example_route(_request=None):
        # The optional argument lets this serve as `call_next` whether the
        # limiter invokes it as call_next() or call_next(request).
        return {"message": "Hello, world!"}

    async def apply_rate_limit(route):
        try:
            return await rate_limiter(_DemoRequest(), route)
        except RateLimitExceeded as e:
            print(f"Rate limit exceeded: {e}")
            return {"error": "Rate limit exceeded"}

    async def main_function():
        # Two more requests than the limit, so the final two should be rejected.
        for attempt in range(1, rate_limiter.default_limit + 3):
            result = await apply_rate_limit(example_route)
            print(f"request {attempt}: {result}")

    asyncio.run(main_function())


if __name__ == "__main__":
    main()