import argparse
import asyncio
import logging
import time
from typing import Dict, Optional

import httpx
import sqlite3
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CacheProxy:
    """
    A reverse proxy helper that caches upstream GET responses in SQLite.

    Entries are keyed by full URL and considered stale after ``ttl`` seconds.
    NOTE(review): the sqlite3 calls below are blocking; inside an async server
    they will stall the event loop under load — consider running them in a
    thread executor. Kept synchronous here to preserve the existing interface.
    """

    def __init__(
        self,
        upstream_url: str,
        db_path: str = "cache.db",
        ttl: int = 3600,
    ):
        """
        Initializes the CacheProxy.

        Args:
            upstream_url: The URL of the upstream server to proxy to.
            db_path: The path to the SQLite database file for caching.
            ttl: The time-to-live (TTL) in seconds for cached responses.
        """
        self.upstream_url = upstream_url
        self.db_path = db_path
        self.ttl = ttl
        self.client = httpx.AsyncClient()
        self._create_table()

    async def aclose(self) -> None:
        """Close the underlying HTTP client and release its connections."""
        await self.client.aclose()

    def _create_table(self):
        """
        Creates the cache table in the SQLite database if it doesn't exist.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS cache (
                    url TEXT PRIMARY KEY,
                    response BLOB,
                    timestamp INTEGER
                )
                """
            )
            conn.commit()
        finally:
            # Always release the file handle, even if the DDL fails.
            conn.close()

    async def get(self, url: str) -> Optional[bytes]:
        """
        Retrieves a cached response from the database.

        Args:
            url: The URL to check for a cached response.

        Returns:
            The cached response as bytes if found and not expired, otherwise None.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            row = conn.execute(
                "SELECT response, timestamp FROM cache WHERE url = ?", (url,)
            ).fetchone()
            if row is None:
                logger.info(f"Cache miss for {url}")
                return None
            response, timestamp = row
            if time.time() - timestamp < self.ttl:
                logger.info(f"Cache hit for {url}")
                return response
            # Stale entry: evict it so the table does not grow without bound.
            conn.execute("DELETE FROM cache WHERE url = ?", (url,))
            conn.commit()
            logger.info(f"Cache expired for {url}")
            return None
        finally:
            conn.close()

    async def set(self, url: str, response: bytes):
        """
        Stores a response in the cache.

        Args:
            url: The URL of the response.
            response: The response as bytes.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute(
                "INSERT OR REPLACE INTO cache (url, response, timestamp) VALUES (?, ?, ?)",
                (url, response, int(time.time())),
            )
            conn.commit()
        finally:
            conn.close()

    async def fetch(self, url: str) -> bytes:
        """
        Fetches a response from the upstream server.

        Args:
            url: The URL to fetch.

        Returns:
            The response body as bytes.

        Raises:
            HTTPException: with the upstream status for 4xx/5xx replies,
                502 for transport failures, 500 for anything unexpected.
        """
        try:
            response = await self.client.get(url)
            response.raise_for_status()  # Raise HTTPStatusError for 4xx/5xx
            return response.content
        except httpx.HTTPStatusError as e:
            # Only HTTPStatusError carries a .response; pass the status through.
            logger.error(f"Error fetching {url}: {e}")
            raise HTTPException(status_code=e.response.status_code, detail=str(e))
        except httpx.HTTPError as e:
            # Transport-level failures (DNS, connect, timeout) have no
            # .response — the previous code raised AttributeError here.
            logger.error(f"Error fetching {url}: {e}")
            raise HTTPException(status_code=502, detail=str(e))
        except Exception as e:
            logger.error(f"Error fetching {url}: {e}")
            raise HTTPException(status_code=500, detail=str(e))

# Module-level ASGI application; the catch-all route is registered below.
app = FastAPI()
# Populated by main() before the server starts; proxy_request rejects
# requests with a 500 while this is still None.
cache_proxy: Optional[CacheProxy] = None


@app.api_route("/{path:path}", methods=["GET"])
async def proxy_request(path: str, request: Request):
    """
    Proxies a GET request to the upstream server, using the cache if available.

    Args:
        path: The path to proxy to.
        request: The incoming request (unused beyond routing).

    Returns:
        The cached or freshly fetched upstream body as an octet-stream.

    Raises:
        HTTPException: 500 if the proxy is not initialized, the status mapped
            by CacheProxy.fetch for upstream failures, or 500 for anything
            unexpected.
    """
    # Read-only access to the module global; no `global` statement needed.
    if cache_proxy is None:
        raise HTTPException(
            status_code=500, detail="Cache proxy not initialized. Check logs."
        )

    url = f"{cache_proxy.upstream_url}/{path}"
    cached_response = await cache_proxy.get(url)

    # `is not None` so a cached *empty* body still counts as a hit.
    if cached_response is not None:
        # StreamingResponse iterates its content; raw bytes would be iterated
        # byte-by-byte as ints, so wrap the body in a one-chunk iterator.
        return StreamingResponse(
            iter([cached_response]), media_type="application/octet-stream"
        )

    try:
        body = await cache_proxy.fetch(url)
        await cache_proxy.set(url, body)
        return StreamingResponse(iter([body]), media_type="application/octet-stream")
    except HTTPException:
        raise  # fetch() already mapped the failure to an HTTP status
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")


async def main(
    upstream_url: Optional[str] = None,
    db_path: Optional[str] = None,
    ttl: Optional[int] = None,
):
    """
    Initialize the module-level cache proxy.

    Calling ``main()`` with no arguments preserves the old behavior of
    reading the module-global ``args`` parsed under ``__main__`` — but that
    global only exists when run as a script, so importing this module and
    calling ``main()`` bare would raise NameError. Explicit parameters make
    the function usable from other entry points.

    Args:
        upstream_url: Upstream base URL; defaults to ``args.upstream_url``.
        db_path: SQLite cache path; defaults to ``args.db_path``.
        ttl: Cache TTL in seconds; defaults to ``args.ttl``.
    """
    global cache_proxy
    cache_proxy = CacheProxy(
        upstream_url=upstream_url if upstream_url is not None else args.upstream_url,
        db_path=db_path if db_path is not None else args.db_path,
        ttl=ttl if ttl is not None else args.ttl,
    )


if __name__ == "__main__":
    # CLI entry point: parse flags, build the proxy, then serve.
    parser = argparse.ArgumentParser(
        description="FastAPI reverse proxy with caching."
    )
    parser.add_argument(
        "--upstream-url", required=True, help="The URL of the upstream server."
    )
    parser.add_argument(
        "--db-path", default="cache.db", help="The path to the SQLite database."
    )
    parser.add_argument(
        "--ttl", type=int, default=3600, help="The cache TTL in seconds."
    )
    # main() reads this module-global `args` — it must be parsed before the
    # asyncio.run(main()) call below.
    args = parser.parse_args()

    # Imported here so the module can be imported without uvicorn installed.
    import uvicorn

    # Initializes the module-global cache_proxy. The coroutine contains no
    # awaits, so the throwaway event loop created by asyncio.run() exits
    # immediately. NOTE(review): the httpx.AsyncClient is therefore created
    # outside the event loop uvicorn will run — confirm httpx binds its loop
    # lazily on first request.
    asyncio.run(main())
    uvicorn.run(app, host="0.0.0.0", port=8000)  # Adjust host and port as needed