// Package middleware provides HTTP middleware for the backend, including
// request authentication via session cookies or Bearer access tokens.
package middleware

import (
	"context"
	"errors"
	"net/http"
	"strings"

	"gitea.hostxtra.co.uk/mrhid6/notely/backend/internal/infrastructure/auth"
)

// ContextKey is a custom type for context keys. Using a dedicated type keeps
// these keys from colliding with plain string keys set by other packages.
type ContextKey string

// Context keys under which authenticated identity values are stored.
const (
	UserIDKey ContextKey = "user_id"
	EmailKey  ContextKey = "email"
	UserKey   ContextKey = "user"
)

// Sentinel errors returned by the context accessors; compare with errors.Is.
var (
	ErrUserIDNotFound = errors.New("user ID not found in context")
	ErrEmailNotFound  = errors.New("email not found in context")
)

// publicPathSuffixes lists request-path suffixes that bypass authentication.
//
// NOTE(review): suffix matching means ANY path ending in one of these
// segments (e.g. "/api/anything/health") skips auth. Consider exact path
// matches once the router's prefix is fixed.
var publicPathSuffixes = []string{
	"/auth/login",
	"/auth/register",
	"/health",
}

// AuthMiddleware authenticates requests using either a server-side session
// (looked up from the "session_id" cookie) or a JWT access token supplied in
// the Authorization header.
type AuthMiddleware struct {
	jwtManager     *auth.JWTManager
	sessionManager *auth.SessionManager
}

// NewAuthMiddleware creates an AuthMiddleware that verifies access tokens with
// jwtManager and looks up/refreshes sessions with sessionManager.
func NewAuthMiddleware(jwtManager *auth.JWTManager, sessionManager *auth.SessionManager) *AuthMiddleware {
	return &AuthMiddleware{
		jwtManager:     jwtManager,
		sessionManager: sessionManager,
	}
}

// Middleware wraps next with authentication.
//
// Requests whose path ends in a public suffix (login, register, health) are
// passed through untouched. Otherwise the request is authenticated by, in
// order:
//  1. a non-empty "session_id" cookie that the session manager accepts, or
//  2. an "Authorization: Bearer <token>" header whose token the JWT manager
//     verifies.
//
// On success the user ID and email are stored in the request context under
// UserIDKey and EmailKey. On failure a 401 Unauthorized response is written
// and next is not called.
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if isPublicPath(r.URL.Path) {
			next.ServeHTTP(w, r)
			return
		}

		if sessionCookie, err := r.Cookie("session_id"); err == nil && sessionCookie.Value != "" {
			sessionData, sessionErr := m.sessionManager.GetSession(r.Context(), sessionCookie.Value)
			if sessionErr == nil {
				// Refreshing is best-effort: the session was just validated, so
				// a failure to extend it must not reject this request.
				_ = m.sessionManager.RefreshSession(r.Context(), sessionCookie.Value)

				ctx := context.WithValue(r.Context(), UserIDKey, sessionData.UserID)
				ctx = context.WithValue(ctx, EmailKey, sessionData.Email)
				next.ServeHTTP(w, r.WithContext(ctx))
				return
			}
			// An invalid or expired session falls through to header auth.
		}

		// Fall back to Authorization header for backwards compatibility.
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		token, ok := bearerToken(authHeader)
		if !ok {
			http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
			return
		}

		claims, err := m.jwtManager.VerifyAccessToken(token)
		if err != nil {
			http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized)
			return
		}

		ctx := context.WithValue(r.Context(), UserIDKey, claims.UserID)
		ctx = context.WithValue(ctx, EmailKey, claims.Email)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// isPublicPath reports whether path ends in one of publicPathSuffixes and
// therefore skips authentication.
func isPublicPath(path string) bool {
	for _, suffix := range publicPathSuffixes {
		if strings.HasSuffix(path, suffix) {
			return true
		}
	}
	return false
}

// bearerToken extracts the token from an Authorization header value of the
// form "Bearer <token>". The scheme name is matched case-insensitively, as
// HTTP auth-scheme names are case-insensitive. ok is false when the header
// is not in that form or the token is empty.
func bearerToken(header string) (token string, ok bool) {
	parts := strings.SplitN(header, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
		return "", false
	}
	token = strings.TrimSpace(parts[1])
	if token == "" {
		return "", false
	}
	return token, true
}

// GetUserIDFromContext returns the user ID stored under UserIDKey by
// AuthMiddleware. It returns ErrUserIDNotFound when no string value is present.
func GetUserIDFromContext(ctx context.Context) (string, error) {
	userID, ok := ctx.Value(UserIDKey).(string)
	if !ok {
		return "", ErrUserIDNotFound
	}
	return userID, nil
}

// GetEmailFromContext returns the email stored under EmailKey by
// AuthMiddleware. It returns ErrEmailNotFound when no string value is present.
func GetEmailFromContext(ctx context.Context) (string, error) {
	email, ok := ctx.Value(EmailKey).(string)
	if !ok {
		return "", ErrEmailNotFound
	}
	return email, nil
}