// Package middleware provides net/http middleware: security response
// headers, a placeholder rate limiter, request logging, and CORS handling.
package middleware

import (
	"fmt"
	"net/http"
	"strings"
)

// contentSecurityPolicy is the Content-Security-Policy value sent by
// SecurityHeaders.
//
// NOTE(review): 'unsafe-inline' in script-src largely defeats CSP's XSS
// protection; consider nonces or hashes once inline scripts are removed.
const contentSecurityPolicy = "default-src 'self'; " +
	"script-src 'self' 'unsafe-inline'; " +
	"style-src 'self' 'unsafe-inline'; " +
	"img-src 'self' data: https:; " +
	"font-src 'self'; " +
	"connect-src 'self'; " +
	"frame-ancestors 'none'"

// SecurityHeaders sets a fixed set of security-related response headers
// (HSTS, MIME-sniffing, legacy XSS filter, framing, CSP, referrer policy)
// on every response and then delegates to next.
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		// HSTS: tell browsers to use HTTPS for this host and its subdomains
		// for one year. Browsers ignore this header on plain-HTTP responses.
		h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")

		// Stop browsers from MIME-sniffing a response away from its declared
		// Content-Type.
		h.Set("X-Content-Type-Options", "nosniff")

		// Explicitly disable the legacy browser XSS auditor. The previous
		// "1; mode=block" value is deprecated: the auditor could itself be
		// abused for information leaks, and modern browsers rely on CSP.
		h.Set("X-XSS-Protection", "0")

		// Clickjacking protection for browsers that predate CSP
		// frame-ancestors (which the policy below also sets).
		h.Set("X-Frame-Options", "DENY")

		h.Set("Content-Security-Policy", contentSecurityPolicy)

		// Send only the origin (not the full URL) on cross-origin requests,
		// and nothing on HTTPS->HTTP downgrades.
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")

		next.ServeHTTP(w, r)
	})
}

// RateLimitMiddleware is a placeholder for request rate limiting. It
// currently performs no limiting: every request is passed straight through.
//
// NOTE(review): before production, back this with a real limiter (for
// example golang.org/x/time/rate per instance, or a shared store when
// running multiple replicas).
type RateLimitMiddleware struct{}

// NewRateLimitMiddleware returns a new (currently no-op) RateLimitMiddleware.
func NewRateLimitMiddleware() *RateLimitMiddleware {
	return &RateLimitMiddleware{}
}

// Middleware wraps next with rate limiting. It is currently a pass-through
// that calls next for every request.
func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// TODO: Implement proper rate limiting using distributed cache.
		next.ServeHTTP(w, r)
	})
}

// LoggingMiddleware prints one line per request to standard output with the
// method, request URI and remote address, then calls next. It logs before
// the handler runs, so response status and latency are not recorded.
func LoggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Printf("[%s] %s %s\n", r.Method, r.RequestURI, r.RemoteAddr)
		next.ServeHTTP(w, r)
	})
}

// isAllowedOrigin reports whether origin is on the CORS allow-list:
// exactly "http://localhost" or "http://localhost:5173", or
// "http://127.0.0.1" with any explicit port.
func isAllowedOrigin(origin string) bool {
	switch origin {
	case "http://localhost", "http://localhost:5173":
		return true
	}
	return strings.HasPrefix(origin, "http://127.0.0.1:")
}

// CORSMiddleware adds CORS headers to every response.
//
// An allowed Origin (see isAllowedOrigin) is echoed back in
// Access-Control-Allow-Origin. Any other Origin, or a request without one,
// gets "http://localhost", which will not match a disallowed caller's origin,
// so the browser blocks the cross-origin read. OPTIONS requests are answered
// with 204 No Content and are not passed to next.
func CORSMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		origin := r.Header.Get("Origin")

		// Access-Control-Allow-Origin depends on the request's Origin on
		// every path (echoed or fallback), so caches must key on Origin
		// regardless of which branch runs. Add (not Set) preserves any Vary
		// values already placed by outer middleware.
		h.Add("Vary", "Origin")

		if isAllowedOrigin(origin) {
			h.Set("Access-Control-Allow-Origin", origin)
		} else {
			h.Set("Access-Control-Allow-Origin", "http://localhost")
		}
		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
		h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
		// Let browsers cache preflight results for 10 minutes.
		h.Set("Access-Control-Max-Age", "600")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}