""" Tests for middleware module """ from unittest.mock import MagicMock, patch import pytest from fastapi import FastAPI, Request, Response from fastapi.testclient import TestClient from middleware.logging_middleware import LoggingMiddleware class TestLoggingMiddleware: """Test LoggingMiddleware class""" @pytest.fixture def app(self): """Create FastAPI app for testing""" app = FastAPI() @app.get("/test") async def test_endpoint(): return {"message": "test"} @app.get("/error") async def error_endpoint(): raise ValueError("Test error") return app @pytest.fixture def middleware(self): """Create LoggingMiddleware instance""" return LoggingMiddleware(MagicMock()) def test_middleware_initialization(self, middleware): """Test middleware initialization""" assert middleware is not None assert hasattr(middleware, "dispatch") @pytest.mark.asyncio async def test_dispatch_successful_request(self, middleware): """Test middleware with successful request""" # Create mock request mock_request = MagicMock(spec=Request) mock_request.method = "GET" mock_request.url.path = "/test" mock_request.client.host = "127.0.0.1" mock_request.headers = {"user-agent": "test-agent"} # Create mock response mock_response = MagicMock(spec=Response) mock_response.status_code = 200 # Create mock call_next async def mock_call_next(request): return mock_response with patch("middleware.logging_middleware.logger") as mock_logger: with patch("time.time", side_effect=[0, 0.5]): # Mock timing result = await middleware.dispatch(mock_request, mock_call_next) assert result == mock_response # Verify logging calls assert mock_logger.info.call_count == 2 # Request start and completion mock_logger.info.assert_any_call("Request started: GET /test from 127.0.0.1 (User-Agent: test-agent)") mock_logger.info.assert_any_call("Request completed: GET /test -> 200 in 0.500s") @pytest.mark.asyncio async def test_dispatch_request_with_exception(self, middleware): """Test middleware with request that raises exception""" # Create mock request mock_request = MagicMock(spec=Request) mock_request.method = "GET" mock_request.url.path = "/error" mock_request.client.host = "127.0.0.1" mock_request.headers = {"user-agent": "test-agent"} # Create mock call_next that raises exception async def mock_call_next(request): raise ValueError("Test error") with patch("middleware.logging_middleware.logger") as mock_logger: with patch("time.time", side_effect=[0, 0.3]): # Mock timing with pytest.raises(ValueError): await middleware.dispatch(mock_request, mock_call_next) # Verify logging calls assert mock_logger.info.call_count == 1 # Request start assert mock_logger.error.call_count == 1 # Error logging mock_logger.info.assert_called_with("Request started: GET /error from 127.0.0.1 (User-Agent: test-agent)") mock_logger.error.assert_called_with("Request failed: GET /error -> Exception: Test error in 0.300s") @pytest.mark.asyncio async def test_dispatch_with_no_client(self, middleware): """Test middleware with request that has no client""" # Create mock request without client mock_request = MagicMock(spec=Request) mock_request.method = "GET" mock_request.url.path = "/test" mock_request.client = None mock_request.headers = {"user-agent": "test-agent"} # Create mock response mock_response = MagicMock(spec=Response) mock_response.status_code = 200 # Create mock call_next async def mock_call_next(request): return mock_response with patch("middleware.logging_middleware.logger") as mock_logger: with patch("time.time", side_effect=[0, 0.5]): result = await middleware.dispatch(mock_request, mock_call_next) assert result == mock_response # Verify logging with "unknown" client mock_logger.info.assert_any_call("Request started: GET /test from unknown (User-Agent: test-agent)") @pytest.mark.asyncio async def test_dispatch_with_no_user_agent(self, middleware): """Test middleware with request that has no user agent""" # Create mock request without user agent mock_request = MagicMock(spec=Request) mock_request.method = "GET" mock_request.url.path = "/test" mock_request.client.host = "127.0.0.1" mock_request.headers = {} # Create mock response mock_response = MagicMock(spec=Response) mock_response.status_code = 200 # Create mock call_next async def mock_call_next(request): return mock_response with patch("middleware.logging_middleware.logger") as mock_logger: with patch("time.time", side_effect=[0, 0.5]): result = await middleware.dispatch(mock_request, mock_call_next) assert result == mock_response # Verify logging with "unknown" user agent mock_logger.info.assert_any_call("Request started: GET /test from 127.0.0.1 (User-Agent: unknown)") @pytest.mark.asyncio async def test_dispatch_timing_accuracy(self, middleware): """Test that timing is calculated correctly""" # Create mock request mock_request = MagicMock(spec=Request) mock_request.method = "POST" mock_request.url.path = "/api/data" mock_request.client.host = "192.168.1.100" mock_request.headers = {"user-agent": "Mozilla/5.0"} # Create mock response mock_response = MagicMock(spec=Response) mock_response.status_code = 201 # Create mock call_next async def mock_call_next(request): return mock_response with patch("middleware.logging_middleware.logger") as mock_logger: with patch("time.time", side_effect=[1000.0, 1000.123]): # 0.123 seconds result = await middleware.dispatch(mock_request, mock_call_next) assert result == mock_response # Verify timing in log message mock_logger.info.assert_any_call("Request completed: POST /api/data -> 201 in 0.123s") @pytest.mark.asyncio async def test_dispatch_different_http_methods(self, middleware): """Test middleware with different HTTP methods""" methods = ["GET", "POST", "PUT", "DELETE", "PATCH"] for method in methods: # Create mock request mock_request = MagicMock(spec=Request) mock_request.method = method mock_request.url.path = f"/{method.lower()}" mock_request.client.host = "127.0.0.1" mock_request.headers = {"user-agent": "test-agent"} # Create mock response mock_response = MagicMock(spec=Response) mock_response.status_code = 200 # Create mock call_next async def mock_call_next(request): return mock_response with patch("middleware.logging_middleware.logger") as mock_logger: with patch("time.time", side_effect=[0, 0.1]): result = await middleware.dispatch(mock_request, mock_call_next) assert result == mock_response # Verify method is logged correctly mock_logger.info.assert_any_call(f"Request started: {method} /{method.lower()} from 127.0.0.1 (User-Agent: test-agent)") @pytest.mark.asyncio async def test_dispatch_different_status_codes(self, middleware): """Test middleware with different status codes""" status_codes = [200, 201, 400, 401, 403, 404, 500, 502, 503] for status_code in status_codes: # Create mock request mock_request = MagicMock(spec=Request) mock_request.method = "GET" mock_request.url.path = "/test" mock_request.client.host = "127.0.0.1" mock_request.headers = {"user-agent": "test-agent"} # Create mock response mock_response = MagicMock(spec=Response) mock_response.status_code = status_code # Create mock call_next async def mock_call_next(request): return mock_response with patch("middleware.logging_middleware.logger") as mock_logger: with patch("time.time", side_effect=[0, 0.1]): result = await middleware.dispatch(mock_request, mock_call_next) assert result == mock_response # Verify status code is logged correctly mock_logger.info.assert_any_call(f"Request completed: GET /test -> {status_code} in 0.100s") def test_middleware_integration_with_fastapi(self, app): """Test middleware integration with FastAPI""" # Add middleware to app app.add_middleware(LoggingMiddleware) # Create test client client = TestClient(app) with patch("middleware.logging_middleware.logger") as mock_logger: response = client.get("/test") assert response.status_code == 200 assert response.json() == {"message": "test"} # Verify middleware was called assert mock_logger.info.call_count >= 2 # At least start and completion def test_middleware_error_handling_integration(self, app): """Test middleware error handling with FastAPI""" # Add middleware to app app.add_middleware(LoggingMiddleware) # Create test client client = TestClient(app) with patch("middleware.logging_middleware.logger") as mock_logger: with pytest.raises(ValueError): client.get("/error") # Verify error was logged assert mock_logger.error.call_count >= 1