package middleware import ( "context" "errors" "net/http" "strings" "github.com/noteapp/backend/internal/infrastructure/auth" ) // ContextKey is a custom type for context keys type ContextKey string const ( UserIDKey ContextKey = "user_id" EmailKey ContextKey = "email" UserKey ContextKey = "user" ) // AuthMiddleware verifies JWT tokens type AuthMiddleware struct { jwtManager *auth.JWTManager } // NewAuthMiddleware creates a new auth middleware func NewAuthMiddleware(jwtManager *auth.JWTManager) *AuthMiddleware { return &AuthMiddleware{ jwtManager: jwtManager, } } // Middleware wraps an HTTP handler with authentication func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip auth for login and register endpoints if strings.HasSuffix(r.URL.Path, "/auth/login") || strings.HasSuffix(r.URL.Path, "/auth/register") || strings.HasSuffix(r.URL.Path, "/health") { next.ServeHTTP(w, r) return } // Extract token from Authorization header. // For GET /files/object, also accept ?token= so markdown images render in-browser. authHeader := r.Header.Get("Authorization") if authHeader == "" && r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/files/object") { if tok := r.URL.Query().Get("token"); tok != "" { authHeader = "Bearer " + tok } } if authHeader == "" { http.Error(w, "Missing authorization header", http.StatusUnauthorized) return } parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 || parts[0] != "Bearer" { http.Error(w, "Invalid authorization header format", http.StatusUnauthorized) return } token := parts[1] // Verify token claims, err := m.jwtManager.VerifyAccessToken(token) if err != nil { http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized) return } // Add claims to context ctx := context.WithValue(r.Context(), UserIDKey, claims.UserID) ctx = context.WithValue(ctx, EmailKey, claims.Email) r = r.WithContext(ctx) next.ServeHTTP(w, r) }) } // GetUserIDFromContext extracts user ID from context func GetUserIDFromContext(ctx context.Context) (string, error) { userID, ok := ctx.Value(UserIDKey).(string) if !ok { return "", errors.New("user ID not found in context") } return userID, nil } // GetEmailFromContext extracts email from context func GetEmailFromContext(ctx context.Context) (string, error) { email, ok := ctx.Value(EmailKey).(string) if !ok { return "", errors.New("email not found in context") } return email, nil }