Files
notely/backend/internal/interfaces/middleware/auth.go
domrichardson 6e642da57a
All checks were successful
Build and Push App Image / build-and-push (push) Successful in 1m27s
fix: fixes to session storage
2026-03-26 10:06:07 +00:00

107 lines
2.9 KiB
Go

package middleware
import (
"context"
"errors"
"net/http"
"strings"
"github.com/noteapp/backend/internal/infrastructure/auth"
)
// ContextKey is the named string type used for every key this package
// stores in a request context.
type ContextKey string

// Keys for values carried in the request context.
const (
	// UserIDKey is the key for the authenticated user's ID.
	UserIDKey = ContextKey("user_id")
	// EmailKey is the key for the authenticated user's email address.
	EmailKey = ContextKey("email")
	// UserKey is the key reserved for a user value.
	// NOTE(review): nothing in this file reads or writes it — confirm usage elsewhere.
	UserKey = ContextKey("user")
)
// AuthMiddleware authenticates incoming HTTP requests using either a
// server-side session or a JWT access token.
type AuthMiddleware struct {
	// jwtManager verifies bearer access tokens.
	jwtManager *auth.JWTManager
	// sessionManager looks up and refreshes cookie-backed sessions.
	sessionManager *auth.SessionManager
}

// NewAuthMiddleware returns an AuthMiddleware that uses jwtManager to verify
// access tokens and sessionManager to resolve sessions.
func NewAuthMiddleware(jwtManager *auth.JWTManager, sessionManager *auth.SessionManager) *AuthMiddleware {
	mw := new(AuthMiddleware)
	mw.jwtManager = jwtManager
	mw.sessionManager = sessionManager
	return mw
}
// Middleware wraps next with authentication.
//
// Requests whose path ends in "/auth/login", "/auth/register" or "/health"
// are passed to next without any check. Every other request is authenticated
// by the first of these that succeeds:
//
//   - a non-empty "session_id" cookie that the session manager resolves to a
//     session (the session is then refreshed on a best-effort basis);
//   - an "Authorization: Bearer <token>" header whose token the JWT manager
//     accepts. The scheme name is matched case-insensitively.
//
// On success the user ID and email are stored in the request context under
// UserIDKey and EmailKey before next is called. Otherwise the request is
// answered with 401 Unauthorized and next is not called.
func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Skip auth for the login, register and health endpoints.
		// NOTE(review): this is a suffix match, so any route ending in one of
		// these segments is also served unauthenticated — confirm no protected
		// route does.
		if strings.HasSuffix(r.URL.Path, "/auth/login") ||
			strings.HasSuffix(r.URL.Path, "/auth/register") ||
			strings.HasSuffix(r.URL.Path, "/health") {
			next.ServeHTTP(w, r)
			return
		}

		// Prefer the session cookie. A missing, empty or unresolvable cookie
		// is not an error by itself: we fall through to the header check.
		if sessionCookie, err := r.Cookie("session_id"); err == nil && sessionCookie.Value != "" {
			sessionData, sessionErr := m.sessionManager.GetSession(r.Context(), sessionCookie.Value)
			if sessionErr == nil {
				// Refreshing is best-effort: the session was just resolved, so
				// a failed refresh must not reject this request.
				_ = m.sessionManager.RefreshSession(r.Context(), sessionCookie.Value)

				ctx := context.WithValue(r.Context(), UserIDKey, sessionData.UserID)
				ctx = context.WithValue(ctx, EmailKey, sessionData.Email)
				next.ServeHTTP(w, r.WithContext(ctx))
				return
			}
		}

		// Fall back to Authorization header for backwards compatibility.
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		// HTTP authentication scheme names are case-insensitive, so accept
		// "bearer", "BEARER", etc. as well as "Bearer".
		parts := strings.SplitN(authHeader, " ", 2)
		if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
			http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
			return
		}
		token := parts[1]

		// Verify token.
		claims, err := m.jwtManager.VerifyAccessToken(token)
		if err != nil {
			// NOTE(review): the verifier's error text is echoed to the client;
			// confirm it never carries sensitive detail.
			http.Error(w, "Invalid token: "+err.Error(), http.StatusUnauthorized)
			return
		}

		// Expose the token's claims to downstream handlers.
		ctx := context.WithValue(r.Context(), UserIDKey, claims.UserID)
		ctx = context.WithValue(ctx, EmailKey, claims.Email)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// GetUserIDFromContext returns the user ID stored in ctx under UserIDKey.
// It returns an error when ctx holds no value for that key or the stored
// value is not a string.
func GetUserIDFromContext(ctx context.Context) (string, error) {
	if id, isString := ctx.Value(UserIDKey).(string); isString {
		return id, nil
	}
	return "", errors.New("user ID not found in context")
}
// GetEmailFromContext returns the email address stored in ctx under EmailKey.
// It returns an error when ctx holds no value for that key or the stored
// value is not a string.
func GetEmailFromContext(ctx context.Context) (string, error) {
	if addr, isString := ctx.Value(EmailKey).(string); isString {
		return addr, nil
	}
	return "", errors.New("email not found in context")
}