From 7c3107375707e8372bf0812babdc59d1040e35bd Mon Sep 17 00:00:00 2001 From: Robert Terhaar Date: Mon, 23 Jun 2025 14:17:59 -0500 Subject: [PATCH 1/3] add middleware edit for req headers --- examples/headers_middleware/main.go | 330 ++++++++++++++++++ examples/headers_middleware/main_test.go | 202 +++++++++++ .../httpserver/middleware/headers/options.go | 90 ++++- 3 files changed, 615 insertions(+), 7 deletions(-) create mode 100644 examples/headers_middleware/main.go create mode 100644 examples/headers_middleware/main_test.go diff --git a/examples/headers_middleware/main.go b/examples/headers_middleware/main.go new file mode 100644 index 0000000..cdab3aa --- /dev/null +++ b/examples/headers_middleware/main.go @@ -0,0 +1,330 @@ +// Package main demonstrates request and response header manipulation using +// go-supervisor's headers middleware. +// +// # Header Manipulation Example +// +// This example shows how to use the new request header manipulation features +// alongside existing response header functionality: +// +// 1. Request Header Operations: +// - Remove potentially sensitive headers (X-Forwarded-For) +// - Add custom request headers (X-Internal-Request) +// - Set specific request headers (X-Request-Source) +// +// 2. Response Header Operations: +// - Add security headers (X-Frame-Options) +// - Set custom API headers (X-API-Version) +// - Remove server identification headers (Server) +// +// 3. Header Inspection Route: +// - Displays all request headers received by the handler +// - Shows how request headers are modified before reaching handlers +// - Demonstrates response headers are added after handler execution +// +// The middleware processes headers in this order: +// Request: remove → set → add (before calling handler) +// Response: remove → set → add (after handler returns) +package main + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "os" + "sort" + "time" + + "github.com/robbyt/go-supervisor/runnables/httpserver" + "github.com/robbyt/go-supervisor/runnables/httpserver/middleware/headers" + "github.com/robbyt/go-supervisor/runnables/httpserver/middleware/logger" + "github.com/robbyt/go-supervisor/runnables/httpserver/middleware/metrics" + "github.com/robbyt/go-supervisor/runnables/httpserver/middleware/recovery" + "github.com/robbyt/go-supervisor/supervisor" +) + +const ( + // Port the HTTP server binds to + ListenOn = ":8082" + + // How long the supervisor waits for the HTTP server to drain before forcefully shutting down + DrainTimeout = 5 * time.Second +) + +// HeaderInfo represents header information for JSON response +type HeaderInfo struct { + Name string `json:"name"` + Values []string `json:"values"` +} + +// HeaderResponse represents the complete header inspection response +type HeaderResponse struct { + Message string `json:"message"` + RequestHeaders []HeaderInfo `json:"request_headers"` + ResponseHeaders []HeaderInfo `json:"response_headers,omitempty"` + Timestamp string `json:"timestamp"` +} + +// buildRoutes sets up HTTP routes demonstrating header manipulation +func buildRoutes(logHandler slog.Handler) ([]httpserver.Route, error) { + // Create base middleware stack + recoveryMw := recovery.New(logHandler.WithGroup("recovery")) + loggingMw := logger.New(logHandler.WithGroup("headers_example")) + metricsMw := metrics.New() + + // Create comprehensive headers middleware that demonstrates both + // request and response header manipulation + headersMw := headers.NewWithOperations( + // Request header operations (applied before handler) + headers.WithRemoveRequest("X-Forwarded-For", "X-Real-IP"), // Remove proxy headers + headers.WithSetRequest(headers.HeaderMap{ + "X-Request-Source": "go-supervisor-example", // Set request source + }), + headers.WithAddRequest(headers.HeaderMap{ + "X-Internal-Request": "true", // Mark as internal + }), + headers.WithAddRequestHeader("X-Processing-Time", time.Now().Format(time.RFC3339)), + + // Response header operations (applied after handler) + headers.WithRemove("Server", "X-Powered-By"), // Remove server identification + headers.WithSet(headers.HeaderMap{ + "X-Frame-Options": "DENY", // Security header + "X-API-Version": "v1.0", // API version + "Content-Type": "application/json", // JSON responses + }), + headers.WithAdd(headers.HeaderMap{ + "X-Custom-Header": "go-supervisor-headers", // Custom header + }), + headers.WithAddHeader("X-Response-Time", time.Now().Format(time.RFC3339)), + ) + + // Common middleware stack + commonMw := []httpserver.HandlerFunc{ + recoveryMw, + loggingMw, + metricsMw, + headersMw, // Headers middleware processes both request and response + } + + // Header inspection handler - shows request headers received after middleware processing + headerInspectionHandler := func(w http.ResponseWriter, r *http.Request) { + // Collect request headers + var requestHeaders []HeaderInfo + for name, values := range r.Header { + requestHeaders = append(requestHeaders, HeaderInfo{ + Name: name, + Values: values, + }) + } + + // Sort headers for consistent output + sort.Slice(requestHeaders, func(i, j int) bool { + return requestHeaders[i].Name < requestHeaders[j].Name + }) + + response := HeaderResponse{ + Message: "Header inspection complete - request headers show middleware effects", + RequestHeaders: requestHeaders, + Timestamp: time.Now().Format(time.RFC3339), + } + + // Response headers will be added by middleware after this handler returns + jsonData, err := json.MarshalIndent(response, "", " ") + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + _, err = fmt.Fprint(w, string(jsonData)) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + return + } + } + + // Simple response handler + simpleHandler := func(w http.ResponseWriter, r *http.Request) { + response := map[string]interface{}{ + "message": "Simple response with header manipulation", + "timestamp": time.Now().Format(time.RFC3339), + "note": "Check response headers - they've been modified by middleware", + } + jsonData, err := json.MarshalIndent(response, "", " ") + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + _, err = fmt.Fprint(w, string(jsonData)) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + return + } + } + + // Create routes + headerRoute, err := httpserver.NewRouteFromHandlerFunc( + "header-inspection", + "/headers", + headerInspectionHandler, + commonMw..., + ) + if err != nil { + return nil, fmt.Errorf("failed to create header inspection route: %w", err) + } + + simpleRoute, err := httpserver.NewRouteFromHandlerFunc( + "simple", + "/", + simpleHandler, + commonMw..., + ) + if err != nil { + return nil, fmt.Errorf("failed to create simple route: %w", err) + } + + // Route with different header middleware for comparison + differentHeadersMw := headers.NewWithOperations( + // Different request operations + headers.WithSetRequestHeader("X-Route-Specific", "different-route"), + headers.WithRemoveRequest("User-Agent"), + + // Different response operations + headers.WithSetHeader("X-Route-Type", "special"), + headers.WithAddHeader("X-Different-Header", "route-specific-value"), + ) + + specialMw := []httpserver.HandlerFunc{ + recoveryMw, + loggingMw, + metricsMw, + differentHeadersMw, // Different headers middleware + } + + specialHandler := func(w http.ResponseWriter, r *http.Request) { + // Show headers for this route + var requestHeaders []HeaderInfo + for name, values := range r.Header { + requestHeaders = append(requestHeaders, HeaderInfo{ + Name: name, + Values: values, + }) + } + + sort.Slice(requestHeaders, func(i, j int) bool { + return requestHeaders[i].Name < requestHeaders[j].Name + }) + + response := HeaderResponse{ + Message: "Special route with different header middleware", + RequestHeaders: requestHeaders, + Timestamp: time.Now().Format(time.RFC3339), + } + + w.Header().Set("Content-Type", "application/json") + jsonData, err := json.MarshalIndent(response, "", " ") + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + _, err = fmt.Fprint(w, string(jsonData)) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + return + } + } + + specialRoute, err := httpserver.NewRouteFromHandlerFunc( + "special", + "/special", + specialHandler, + specialMw..., + ) + if err != nil { + return nil, fmt.Errorf("failed to create special route: %w", err) + } + + return httpserver.Routes{*simpleRoute, *headerRoute, *specialRoute}, nil +} + +// createHTTPServer creates the HTTP server with configured routes +func createHTTPServer( + routes []httpserver.Route, + logHandler slog.Handler, +) (*httpserver.Runner, error) { + config, err := httpserver.NewConfig(ListenOn, routes, httpserver.WithDrainTimeout(DrainTimeout)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP server config: %w", err) + } + + return httpserver.NewRunner( + httpserver.WithConfig(config), + httpserver.WithLogHandler(logHandler), + ) +} + +// createSupervisor initializes go-supervisor with the HTTP server +func createSupervisor( + ctx context.Context, + logHandler slog.Handler, + runnable supervisor.Runnable, +) (*supervisor.PIDZero, error) { + sv, err := supervisor.New( + supervisor.WithContext(ctx), + supervisor.WithLogHandler(logHandler), + supervisor.WithRunnables(runnable)) + if err != nil { + return nil, fmt.Errorf("failed to create supervisor: %w", err) + } + + return sv, nil +} + +func main() { + // Configure the logger + handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + slog.SetDefault(slog.New(handler)) + + // Create base context + ctx := context.Background() + + // Build routes + routes, err := buildRoutes(handler) + if err != nil { + slog.Error("Failed to build routes", "error", err) + os.Exit(1) + } + + // Create the HTTP server + httpServer, err := createHTTPServer(routes, handler) + if err != nil { + slog.Error("Failed to create HTTP server", "error", err) + os.Exit(1) + } + + // Create the supervisor + sv, err := createSupervisor(ctx, handler, httpServer) + if err != nil { + slog.Error("Failed to setup server", "error", err) + os.Exit(1) + } + + // Start the supervisor + slog.Info("Starting headers middleware example server", "listen", ListenOn) + slog.Info("Available endpoints:") + slog.Info(" GET / - Simple response with header manipulation") + slog.Info(" GET /headers - Inspect request headers (shows middleware effects)") + slog.Info(" GET /special - Route with different header middleware") + slog.Info("") + slog.Info("Test with curl:") + slog.Info( + ` curl -H "X-Forwarded-For: 192.168.1.1" -H "User-Agent: test" -v http://localhost:8082/headers`, + ) + slog.Info(" Notice: X-Forwarded-For is removed, custom headers are added") + + if err := sv.Run(); err != nil { + slog.Error("Supervisor failed", "error", err) + os.Exit(1) + } +} diff --git a/examples/headers_middleware/main_test.go b/examples/headers_middleware/main_test.go new file mode 100644 index 0000000..4c23c70 --- /dev/null +++ b/examples/headers_middleware/main_test.go @@ -0,0 +1,202 @@ +package main + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestHeadersMiddlewareExample tests the headers middleware functionality +func TestHeadersMiddlewareExample(t *testing.T) { + t.Parallel() + + // Create a test logger that discards output + logHandler := slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}) + + // Create a context with timeout for the test + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Build routes + routes, err := buildRoutes(logHandler) + require.NoError(t, err, "Failed to build routes") + require.NotEmpty(t, routes, "Routes should not be empty") + + // Create the HTTP server + httpServer, err := createHTTPServer(routes, logHandler) + require.NoError(t, err, "Failed to create HTTP server") + require.NotNil(t, httpServer, "HTTP server should not be nil") + + // Create the supervisor + sv, err := createSupervisor(ctx, logHandler, httpServer) + require.NoError(t, err, "Failed to create supervisor") + require.NotNil(t, sv, "Supervisor should not be nil") + + // Start the server in a goroutine + errCh := make(chan error, 1) + go func() { + errCh <- sv.Run() + }() + + // Give the server time to start + time.Sleep(200 * time.Millisecond) + + t.Run("TestSimpleRoute", func(t *testing.T) { + // Test simple route + resp, err := http.Get("http://localhost:8082/") + require.NoError(t, err, "Failed to make GET request") + defer func() { + assert.NoError(t, resp.Body.Close()) + }() + + // Check response headers are set by middleware + assert.Equal(t, "DENY", resp.Header.Get("X-Frame-Options")) + assert.Equal(t, "v1.0", resp.Header.Get("X-API-Version")) + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + assert.Equal(t, "go-supervisor-headers", resp.Header.Get("X-Custom-Header")) + assert.NotEmpty(t, resp.Header.Get("X-Response-Time")) + + // Check that server identification headers are removed + assert.Empty(t, resp.Header.Get("Server")) + assert.Empty(t, resp.Header.Get("X-Powered-By")) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("TestHeaderInspectionRoute", func(t *testing.T) { + // Create request with headers that should be modified + req, err := http.NewRequest("GET", "http://localhost:8082/headers", nil) + require.NoError(t, err) + + // Add headers that should be removed by middleware + req.Header.Set("X-Forwarded-For", "192.168.1.1") + req.Header.Set("X-Real-IP", "10.0.0.1") + req.Header.Set("User-Agent", "test-agent") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err, "Failed to make request") + defer func() { + assert.NoError(t, resp.Body.Close()) + }() + + // Check response headers + assert.Equal(t, "DENY", resp.Header.Get("X-Frame-Options")) + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + // Read and parse response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "Failed to read response body") + + var headerResponse HeaderResponse + err = json.Unmarshal(body, &headerResponse) + require.NoError(t, err, "Failed to unmarshal response") + + // Check that request headers were properly modified + headerMap := make(map[string][]string) + for _, h := range headerResponse.RequestHeaders { + headerMap[h.Name] = h.Values + } + + // Headers that should be removed + assert.NotContains(t, headerMap, "X-Forwarded-For", "X-Forwarded-For should be removed") + assert.NotContains(t, headerMap, "X-Real-IP", "X-Real-IP should be removed") + + // Headers that should be added/set by middleware + assert.Contains(t, headerMap, "X-Request-Source") + assert.Equal(t, []string{"go-supervisor-example"}, headerMap["X-Request-Source"]) + assert.Contains(t, headerMap, "X-Internal-Request") + assert.Equal(t, []string{"true"}, headerMap["X-Internal-Request"]) + assert.Contains(t, headerMap, "X-Processing-Time") + + // User-Agent should still be present (not removed in main middleware) + assert.Contains(t, headerMap, "User-Agent") + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("TestSpecialRoute", func(t *testing.T) { + // Create request with User-Agent that should be removed by special route middleware + req, err := http.NewRequest("GET", "http://localhost:8082/special", nil) + require.NoError(t, err) + req.Header.Set("User-Agent", "test-agent") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err, "Failed to make request") + defer func() { + assert.NoError(t, resp.Body.Close()) + }() + + // Check special route response headers + assert.Equal(t, "special", resp.Header.Get("X-Route-Type")) + assert.Equal(t, "route-specific-value", resp.Header.Get("X-Different-Header")) + + // Read and parse response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err, "Failed to read response body") + + var headerResponse HeaderResponse + err = json.Unmarshal(body, &headerResponse) + require.NoError(t, err, "Failed to unmarshal response") + + // Check that User-Agent was removed by special route middleware + headerMap := make(map[string][]string) + for _, h := range headerResponse.RequestHeaders { + headerMap[h.Name] = h.Values + } + + assert.NotContains( + t, + headerMap, + "User-Agent", + "User-Agent should be removed by special route", + ) + assert.Contains(t, headerMap, "X-Route-Specific") + assert.Equal(t, []string{"different-route"}, headerMap["X-Route-Specific"]) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + // Stop the supervisor + sv.Shutdown() + + // Wait for server to stop or timeout + select { + case err := <-errCh: + // This is expected when shutdown is called + if err != nil { + t.Logf("Server stopped with: %v", err) + } + case <-time.After(2 * time.Second): + t.Error("Server did not stop within timeout") + } +} + +// TestBuildRoutes tests route building in isolation +func TestBuildRoutes(t *testing.T) { + t.Parallel() + + logHandler := slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}) + + routes, err := buildRoutes(logHandler) + require.NoError(t, err, "buildRoutes should not return an error") + require.Len(t, routes, 3, "Should have exactly 3 routes") + + // Check route paths since route names are not publicly accessible + routePaths := make([]string, len(routes)) + for i, route := range routes { + routePaths[i] = route.Path + } + + assert.Contains(t, routePaths, "/") + assert.Contains(t, routePaths, "/headers") + assert.Contains(t, routePaths, "/special") +} diff --git a/runnables/httpserver/middleware/headers/options.go b/runnables/httpserver/middleware/headers/options.go index 4e4d16e..195dad5 100644 --- a/runnables/httpserver/middleware/headers/options.go +++ b/runnables/httpserver/middleware/headers/options.go @@ -10,9 +10,12 @@ import ( type HeaderOperation func(*headerOperations) type headerOperations struct { - setHeaders http.Header - addHeaders http.Header - removeHeaders []string + setHeaders http.Header + addHeaders http.Header + removeHeaders []string + setRequestHeaders http.Header + addRequestHeaders http.Header + removeRequestHeaders []string } // WithSet creates an operation to set (replace) headers @@ -66,8 +69,59 @@ func WithRemove(headerNames ...string) HeaderOperation { } } +// WithSetRequest creates an operation to set (replace) request headers +func WithSetRequest(headers HeaderMap) HeaderOperation { + return func(ops *headerOperations) { + if ops.setRequestHeaders == nil { + ops.setRequestHeaders = make(http.Header) + } + for key, value := range headers { + ops.setRequestHeaders.Set(key, value) + } + } +} + +// WithSetRequestHeader creates an operation to set a single request header +func WithSetRequestHeader(key, value string) HeaderOperation { + return func(ops *headerOperations) { + if ops.setRequestHeaders == nil { + ops.setRequestHeaders = make(http.Header) + } + ops.setRequestHeaders.Set(key, value) + } +} + +// WithAddRequest creates an operation to add (append) request headers +func WithAddRequest(headers HeaderMap) HeaderOperation { + return func(ops *headerOperations) { + if ops.addRequestHeaders == nil { + ops.addRequestHeaders = make(http.Header) + } + for key, value := range headers { + ops.addRequestHeaders.Add(key, value) + } + } +} + +// WithAddRequestHeader creates an operation to add a single request header +func WithAddRequestHeader(key, value string) HeaderOperation { + return func(ops *headerOperations) { + if ops.addRequestHeaders == nil { + ops.addRequestHeaders = make(http.Header) + } + ops.addRequestHeaders.Add(key, value) + } +} + +// WithRemoveRequest creates an operation to remove request headers +func WithRemoveRequest(headerNames ...string) HeaderOperation { + return func(ops *headerOperations) { + ops.removeRequestHeaders = append(ops.removeRequestHeaders, headerNames...) + } +} + // NewWithOperations creates a middleware with full header control using functional options. -// Operations are executed in order: remove → set → add +// Operations are executed in order: remove → set → add (for both request and response headers) func NewWithOperations(operations ...HeaderOperation) httpserver.HandlerFunc { ops := &headerOperations{} for _, operation := range operations { @@ -75,21 +129,43 @@ func NewWithOperations(operations ...HeaderOperation) httpserver.HandlerFunc { } return func(rp *httpserver.RequestProcessor) { + request := rp.Request() writer := rp.Writer() - // 1. Remove headers first + // Request header manipulation (before calling Next) + // 1. Remove request headers first + for _, key := range ops.removeRequestHeaders { + request.Header.Del(key) + } + + // 2. Set request headers (replace) + for key, values := range ops.setRequestHeaders { + if len(values) > 0 { + request.Header.Set(key, values[0]) + } + } + + // 3. Add request headers (append) + for key, values := range ops.addRequestHeaders { + for _, value := range values { + request.Header.Add(key, value) + } + } + + // Response header manipulation (existing functionality) + // 1. Remove response headers first for _, key := range ops.removeHeaders { writer.Header().Del(key) } - // 2. Set headers (replace) + // 2. Set response headers (replace) for key, values := range ops.setHeaders { if len(values) > 0 { writer.Header().Set(key, values[0]) } } - // 3. Add headers (append) + // 3. Add response headers (append) for key, values := range ops.addHeaders { for _, value := range values { writer.Header().Add(key, value) From 0371f014a11eec2468de78d782475de26347f4f6 Mon Sep 17 00:00:00 2001 From: Robert Terhaar Date: Mon, 23 Jun 2025 14:41:51 -0500 Subject: [PATCH 2/3] remove the sort from the examples --- examples/headers_middleware/main.go | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/examples/headers_middleware/main.go b/examples/headers_middleware/main.go index cdab3aa..93d1c9f 100644 --- a/examples/headers_middleware/main.go +++ b/examples/headers_middleware/main.go @@ -1,5 +1,4 @@ -// Package main demonstrates request and response header manipulation using -// go-supervisor's headers middleware. +// Example of request/response header manipulation using a middleware. // // # Header Manipulation Example // @@ -33,7 +32,6 @@ import ( "log/slog" "net/http" "os" - "sort" "time" "github.com/robbyt/go-supervisor/runnables/httpserver" @@ -118,11 +116,6 @@ func buildRoutes(logHandler slog.Handler) ([]httpserver.Route, error) { }) } - // Sort headers for consistent output - sort.Slice(requestHeaders, func(i, j int) bool { - return requestHeaders[i].Name < requestHeaders[j].Name - }) - response := HeaderResponse{ Message: "Header inspection complete - request headers show middleware effects", RequestHeaders: requestHeaders, @@ -144,7 +137,7 @@ func buildRoutes(logHandler slog.Handler) ([]httpserver.Route, error) { // Simple response handler simpleHandler := func(w http.ResponseWriter, r *http.Request) { - response := map[string]interface{}{ + response := map[string]any{ "message": "Simple response with header manipulation", "timestamp": time.Now().Format(time.RFC3339), "note": "Check response headers - they've been modified by middleware", @@ -210,10 +203,6 @@ func buildRoutes(logHandler slog.Handler) ([]httpserver.Route, error) { }) } - sort.Slice(requestHeaders, func(i, j int) bool { - return requestHeaders[i].Name < requestHeaders[j].Name - }) - response := HeaderResponse{ Message: "Special route with different header middleware", RequestHeaders: requestHeaders, From 56552477dbf2cc95a793846bf5502680ee897f6c Mon Sep 17 00:00:00 2001 From: Robert Terhaar Date: Mon, 23 Jun 2025 14:53:32 -0500 Subject: [PATCH 3/3] add unit tests for the request header editor --- .../headers/options_request_test.go | 526 ++++++++++++++++++ 1 file changed, 526 insertions(+) create mode 100644 runnables/httpserver/middleware/headers/options_request_test.go diff --git a/runnables/httpserver/middleware/headers/options_request_test.go b/runnables/httpserver/middleware/headers/options_request_test.go new file mode 100644 index 0000000..27cf846 --- /dev/null +++ b/runnables/httpserver/middleware/headers/options_request_test.go @@ -0,0 +1,526 @@ +package headers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/robbyt/go-supervisor/runnables/httpserver" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequestHeaderOperations(t *testing.T) { + t.Parallel() + + t.Run("WithSetRequest sets request headers", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithSetRequest(HeaderMap{ + "X-Request-ID": "req-123", + "X-Forwarded-Host": "example.com", + }), + ) + + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Equal(t, "req-123", capturedHeaders.Get("X-Request-ID")) + assert.Equal(t, "example.com", capturedHeaders.Get("X-Forwarded-Host")) + }) + + t.Run("WithSetRequestHeader sets single request header", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithSetRequestHeader("X-Custom", "test-value"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Equal(t, "test-value", capturedHeaders.Get("X-Custom")) + }) + + t.Run("WithAddRequest adds request headers", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithAddRequest(HeaderMap{ + "X-Tags": "tag1", + "X-Meta": "info", + }), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Tags", "existing-tag") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + tags := capturedHeaders.Values("X-Tags") + assert.Len(t, tags, 2) + assert.Contains(t, tags, "existing-tag") + assert.Contains(t, tags, "tag1") + assert.Equal(t, "info", capturedHeaders.Get("X-Meta")) + }) + + t.Run("WithAddRequestHeader adds single request header", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithAddRequestHeader("Accept-Encoding", "gzip"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Accept-Encoding", "deflate") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + encodings := capturedHeaders.Values("Accept-Encoding") + assert.Len(t, encodings, 2) + assert.Contains(t, encodings, "deflate") + assert.Contains(t, encodings, "gzip") + }) + + t.Run("WithRemoveRequest removes request headers", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithRemoveRequest("X-Forwarded-For", "X-Real-IP"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", "192.168.1.1") + req.Header.Set("X-Real-IP", "10.0.0.1") + req.Header.Set("User-Agent", "test-agent") + req.Header.Set("Authorization", "Bearer token123") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Empty(t, capturedHeaders.Get("X-Forwarded-For")) + assert.Empty(t, capturedHeaders.Get("X-Real-IP")) + assert.Equal(t, "test-agent", capturedHeaders.Get("User-Agent")) + assert.Equal(t, "Bearer token123", capturedHeaders.Get("Authorization")) + }) + + t.Run("request header operation ordering: remove → set → add", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithRemoveRequest("X-Test"), + WithSetRequest(HeaderMap{"X-Test": "set-value"}), + WithAddRequest(HeaderMap{"X-Test": "add-value"}), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Test", "original-value") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + values := capturedHeaders.Values("X-Test") + assert.Len(t, values, 2) + assert.Contains(t, values, "set-value") + assert.Contains(t, values, "add-value") + assert.NotContains(t, values, "original-value") + }) + + t.Run("multiple WithSetRequestHeader calls for same key", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithSetRequestHeader("X-Version", "1.0"), + WithSetRequestHeader("X-Version", "2.0"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Equal(t, "2.0", capturedHeaders.Get("X-Version")) + values := capturedHeaders.Values("X-Version") + assert.Len(t, values, 1) + }) + + t.Run("multiple WithAddRequestHeader calls for same key", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithAddRequestHeader("X-Trace", "trace1"), + WithAddRequestHeader("X-Trace", "trace2"), + WithAddRequestHeader("X-Trace", "trace3"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + traces := capturedHeaders.Values("X-Trace") + assert.Len(t, traces, 3) + assert.Contains(t, traces, "trace1") + assert.Contains(t, traces, "trace2") + assert.Contains(t, traces, "trace3") + }) + + t.Run("WithSetRequestHeader overwrites existing request header", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithSetRequestHeader("User-Agent", "custom-agent/1.0"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("User-Agent", "original-agent") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Equal(t, "custom-agent/1.0", capturedHeaders.Get("User-Agent")) + values := capturedHeaders.Values("User-Agent") + assert.Len(t, values, 1) + }) + + t.Run("WithAddRequestHeader appends to existing request header", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithAddRequestHeader("Accept", "application/json"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Accept", "text/html") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + accepts := capturedHeaders.Values("Accept") + assert.Len(t, accepts, 2) + assert.Contains(t, accepts, "text/html") + assert.Contains(t, accepts, "application/json") + }) + + t.Run("remove non-existent request headers", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithRemoveRequest("Non-Existent-Header"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Existing-Header", "value") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Equal(t, "value", capturedHeaders.Get("Existing-Header")) + assert.Empty(t, capturedHeaders.Get("Non-Existent-Header")) + }) + + t.Run("empty request header operations", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithSetRequest(HeaderMap{}), + WithAddRequest(HeaderMap{}), + WithRemoveRequest(), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Existing", "value") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Equal(t, "value", capturedHeaders.Get("Existing")) + }) +} + +func TestRequestAndResponseHeaderCombinations(t *testing.T) { + t.Parallel() + + t.Run("request and response headers work together", func(t *testing.T) { + var capturedRequestHeaders http.Header + middleware := NewWithOperations( + // Request operations + WithRemoveRequest("X-Forwarded-For"), + WithSetRequest(HeaderMap{"X-Request-Source": "middleware"}), + WithAddRequest(HeaderMap{"X-Request-Tag": "processed"}), + + // Response operations + WithRemove("Server"), + WithSet(HeaderMap{"X-API-Version": "v1.0"}), + WithAdd(HeaderMap{"X-Response-Tag": "enhanced"}), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", "192.168.1.1") + req.Header.Set("User-Agent", "test-agent") + rec := httptest.NewRecorder() + rec.Header().Set("Server", "nginx/1.0") + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedRequestHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + // Verify request header modifications + assert.Empty(t, capturedRequestHeaders.Get("X-Forwarded-For")) + assert.Equal(t, "test-agent", capturedRequestHeaders.Get("User-Agent")) + assert.Equal(t, "middleware", capturedRequestHeaders.Get("X-Request-Source")) + assert.Equal(t, "processed", capturedRequestHeaders.Get("X-Request-Tag")) + + // Verify response header modifications + assert.Empty(t, rec.Header().Get("Server")) + assert.Equal(t, "v1.0", rec.Header().Get("X-API-Version")) + assert.Equal(t, "enhanced", rec.Header().Get("X-Response-Tag")) + }) + + t.Run("same header name in request and response operations", func(t *testing.T) { + var capturedRequestHeaders http.Header + middleware := NewWithOperations( + WithSetRequest(HeaderMap{"X-Version": "request-v1"}), + WithSet(HeaderMap{"X-Version": "response-v1"}), + ) + + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedRequestHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Equal(t, "request-v1", capturedRequestHeaders.Get("X-Version")) + assert.Equal(t, "response-v1", rec.Header().Get("X-Version")) + }) + + t.Run("complex real-world scenario: proxy header cleanup", func(t *testing.T) { + var capturedRequestHeaders http.Header + middleware := NewWithOperations( + // Remove proxy headers from request + WithRemoveRequest("X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto"), + // Add internal tracking headers to request + WithSetRequest(HeaderMap{ + "X-Internal-Request": "true", + "X-Request-ID": "req-12345", + }), + // Add security response headers + WithSet(HeaderMap{ + "X-Frame-Options": "DENY", + "X-Content-Type-Options": "nosniff", + }), + // Remove server identification from response + WithRemove("Server", "X-Powered-By"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", "192.168.1.1, 10.0.0.1") + req.Header.Set("X-Real-IP", "192.168.1.1") + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("User-Agent", "Mozilla/5.0") + rec := httptest.NewRecorder() + rec.Header().Set("Server", "nginx/1.18") + rec.Header().Set("X-Powered-By", "Express") + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedRequestHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + // Verify proxy headers removed from request + assert.Empty(t, capturedRequestHeaders.Get("X-Forwarded-For")) + assert.Empty(t, capturedRequestHeaders.Get("X-Real-IP")) + assert.Empty(t, capturedRequestHeaders.Get("X-Forwarded-Proto")) + + // Verify internal headers added to request + assert.Equal(t, "true", capturedRequestHeaders.Get("X-Internal-Request")) + assert.Equal(t, "req-12345", capturedRequestHeaders.Get("X-Request-ID")) + + // Verify original headers preserved in request + assert.Equal(t, "Mozilla/5.0", capturedRequestHeaders.Get("User-Agent")) + + // Verify server identification removed from response + assert.Empty(t, rec.Header().Get("Server")) + assert.Empty(t, rec.Header().Get("X-Powered-By")) + + // Verify security headers added to response + assert.Equal(t, "DENY", rec.Header().Get("X-Frame-Options")) + assert.Equal(t, "nosniff", rec.Header().Get("X-Content-Type-Options")) + }) +} + +func TestRequestHeaderEdgeCases(t *testing.T) { + t.Parallel() + + t.Run("empty request header keys and values", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithSetRequestHeader("", "empty-key"), + WithAddRequestHeader("Empty-Value", ""), + ) + + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Equal(t, "", capturedHeaders.Get("Empty-Value")) + }) + + t.Run("request header case sensitivity", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithSetRequestHeader("content-type", "application/json"), + WithAddRequestHeader("Content-Type", "charset=utf-8"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + values := capturedHeaders.Values("Content-Type") + assert.Len(t, values, 2) + assert.Contains(t, values, "application/json") + assert.Contains(t, values, "charset=utf-8") + }) + + t.Run("remove request headers with non-canonical names", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithRemoveRequest("content-type", "user-agent"), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Content-Type", "text/html") + req.Header.Set("User-Agent", "test-agent") + req.Header.Set("Accept", "text/html") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + // Go canonicalizes header names for removal + assert.Empty(t, capturedHeaders.Get("Content-Type")) + assert.Empty(t, capturedHeaders.Get("User-Agent")) + assert.Equal(t, "text/html", capturedHeaders.Get("Accept")) + }) + + t.Run("nil request header maps", func(t *testing.T) { + var capturedHeaders http.Header + middleware := NewWithOperations( + WithSetRequest(nil), + WithAddRequest(nil), + ) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Existing", "value") + rec := httptest.NewRecorder() + + route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", + func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + }, middleware) + require.NoError(t, err) + + route.ServeHTTP(rec, req) + + assert.Equal(t, "value", capturedHeaders.Get("Existing")) + }) +}