package handlers import ( "encoding/json" "net/http" "net/url" "os" "strings" "gitea.hostxtra.co.uk/mrhid6/notely/backend/internal/application/dto" "gitea.hostxtra.co.uk/mrhid6/notely/backend/internal/application/services" "gitea.hostxtra.co.uk/mrhid6/notely/backend/internal/infrastructure/auth" "github.com/gorilla/mux" "go.mongodb.org/mongo-driver/v2/bson" ) // AuthHandler handles authentication endpoints type AuthHandler struct { authService *services.AuthService sessionManager *auth.SessionManager } // NewAuthHandler creates a new auth handler func NewAuthHandler(authService *services.AuthService, sessionManager *auth.SessionManager) *AuthHandler { return &AuthHandler{ authService: authService, sessionManager: sessionManager, } } const sessionCookieName = "session_id" // Register handles user registration func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } var req dto.RegisterRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } // Basic validation if req.Email == "" || req.Password == "" || req.Username == "" { http.Error(w, "Missing required fields", http.StatusBadRequest) return } response, err := h.authService.Register(r.Context(), &req) if err != nil { if strings.Contains(strings.ToLower(err.Error()), "registration is currently disabled") { http.Error(w, err.Error(), http.StatusForbidden) return } http.Error(w, err.Error(), http.StatusBadRequest) return } if err := h.setSessionCookie(w, r, response.User); err != nil { http.Error(w, "Failed to create session", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) } // Login handles user login func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not 
allowed", http.StatusMethodNotAllowed) return } var req dto.LoginRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } response, err := h.authService.Login(r.Context(), &req) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } if err := h.setSessionCookie(w, r, response.User); err != nil { http.Error(w, "Failed to create session", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) } // Logout handles user logout func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) { sessionCookie, err := r.Cookie(sessionCookieName) if err == nil { _ = h.sessionManager.DeleteSession(r.Context(), sessionCookie.Value) } h.clearSessionCookie(w, r) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"}) } // ListProviders returns all active OAuth/OIDC providers. func (h *AuthHandler) ListProviders(w http.ResponseWriter, r *http.Request) { providers, err := h.authService.ListProviders(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{"providers": providers}) } // ListProvidersForAdmin returns all OAuth/OIDC providers, including inactive ones. func (h *AuthHandler) ListProvidersForAdmin(w http.ResponseWriter, r *http.Request) { providers, err := h.authService.ListProvidersForAdmin(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{"providers": providers}) } // CreateProvider stores a new OAuth/OIDC provider configuration. 
// It decodes a JSON dto.CreateAuthProviderRequest body, hands it to the auth
// service, and replies 201 Created with the stored provider as JSON.
// A malformed body or a service error both produce 400 Bad Request.
func (h *AuthHandler) CreateProvider(w http.ResponseWriter, r *http.Request) {
	var payload dto.CreateAuthProviderRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	created, createErr := h.authService.CreateProvider(r.Context(), &payload)
	if createErr != nil {
		http.Error(w, createErr.Error(), http.StatusBadRequest)
		return
	}

	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	json.NewEncoder(w).Encode(created)
}

// UpdateProvider updates an existing OAuth/OIDC provider configuration.
// The provider is addressed by the hex ObjectID in the "providerId" route
// variable and the body is a JSON dto.UpdateAuthProviderRequest. An invalid
// ID, a malformed body, or a service error each yields 400 Bad Request; on
// success the updated provider is written back as JSON.
func (h *AuthHandler) UpdateProvider(w http.ResponseWriter, r *http.Request) {
	rawID := mux.Vars(r)["providerId"]
	id, parseErr := bson.ObjectIDFromHex(rawID)
	if parseErr != nil {
		http.Error(w, "Invalid provider ID", http.StatusBadRequest)
		return
	}

	var payload dto.UpdateAuthProviderRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	updated, updateErr := h.authService.UpdateProvider(r.Context(), id, &payload)
	if updateErr != nil {
		http.Error(w, updateErr.Error(), http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	encoder := json.NewEncoder(w)
	encoder.Encode(updated)
}

// StartProviderLogin redirects the browser to the selected provider.
// It generates a random state token, asks the auth service for the provider's
// authorization URL (using this backend's callback route as the redirect
// URI), stores the state in a 10-minute HttpOnly "oauth_state" cookie, and
// answers with a 302 redirect. An invalid provider ID or a failed URL lookup
// yields 400; a state-generation failure yields 500.
func (h *AuthHandler) StartProviderLogin(w http.ResponseWriter, r *http.Request) {
	providerID, err := bson.ObjectIDFromHex(mux.Vars(r)["providerId"])
	if err != nil {
		http.Error(w, "Invalid provider ID", http.StatusBadRequest)
		return
	}

	state, err := auth.GenerateStateToken()
	if err != nil {
		http.Error(w, "Failed to create OAuth state", http.StatusInternalServerError)
		return
	}

	redirectURI := buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback")
	authorizationURL, err := h.authService.BuildProviderAuthorizationURL(r.Context(), providerID, redirectURI, state)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Set the state cookie only once the authorization URL is known, so a
	// failed lookup does not leave a stray state cookie behind. Lax (not
	// Strict) lets the browser send it on the provider's top-level redirect
	// back to the callback.
	http.SetCookie(w, &http.Cookie{
		Name:     "oauth_state",
		Value:    state,
		Path:     "/",
		MaxAge:   10 * 60,
		HttpOnly: true,
		Secure:   isSecureRequest(r),
		SameSite: http.SameSiteLaxMode,
	})
	http.Redirect(w, r, authorizationURL, http.StatusFound)
}

// CompleteProviderLogin handles the provider's redirect back to the backend.
// It checks the "state" query parameter against the "oauth_state" cookie,
// exchanges the authorization code via the auth service, starts a session for
// the returned user, and redirects the browser to the frontend login page
// with status=oauth_success, or status=oauth_error plus a message.
// An invalid provider ID or state mismatch is answered directly with 400.
func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Request) {
	providerID, err := bson.ObjectIDFromHex(mux.Vars(r)["providerId"])
	if err != nil {
		http.Error(w, "Invalid provider ID", http.StatusBadRequest)
		return
	}

	query := r.URL.Query()
	stateCookie, err := r.Cookie("oauth_state")
	if err != nil || stateCookie.Value == "" || stateCookie.Value != query.Get("state") {
		http.Error(w, "Invalid OAuth state", http.StatusBadRequest)
		return
	}

	// The state is single-use: expire it now so every outcome below (provider
	// error, failed exchange, failed session, success) consumes it.
	http.SetCookie(w, &http.Cookie{
		Name:     "oauth_state",
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   isSecureRequest(r),
		SameSite: http.SameSiteLaxMode,
	})

	// A provider that rejects or cancels the request redirects back with an
	// "error" parameter instead of an authorization code (RFC 6749 §4.1.2.1).
	// NOTE(review): this text reaches the frontend via the query string; the
	// frontend must render it as plain text.
	if providerErr := query.Get("error"); providerErr != "" {
		message := query.Get("error_description")
		if message == "" {
			message = providerErr
		}
		http.Redirect(w, r, buildFrontendLoginURL("oauth_error", message), http.StatusFound)
		return
	}

	code := query.Get("code")
	if code == "" {
		http.Redirect(w, r, buildFrontendLoginURL("oauth_error", "Missing authorization code"), http.StatusFound)
		return
	}

	// Must match the redirect URI used in StartProviderLogin.
	redirectURI := buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback")
	response, err := h.authService.CompleteProviderLogin(r.Context(), providerID, code, redirectURI)
	if err != nil {
		http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error()), http.StatusFound)
		return
	}

	if err := h.setSessionCookie(w, r, response.User); err != nil {
		http.Redirect(w, r, buildFrontendLoginURL("oauth_error", "Failed to create session"), http.StatusFound)
		return
	}

	http.Redirect(w, r, buildFrontendLoginURL("oauth_success", ""), http.StatusFound)
}

// RefreshToken handles token refresh for cookie-based sessions. It accepts
// only POST, looks up the session named by the session cookie, tries to
// extend it (re-issuing the cookie with a fresh Max-Age when that succeeds),
// and responds with the stored session data and the session TTL in seconds.
// A missing or invalid session yields 401; an invalid one also clears the
// cookie.
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	cookie, err := r.Cookie(sessionCookieName)
	if err != nil {
		http.Error(w, "Session not found", http.StatusUnauthorized)
		return
	}

	sessionData, err := h.sessionManager.GetSession(r.Context(), cookie.Value)
	if err != nil {
		// Drop the stale cookie (as Me does) so the browser stops sending it.
		h.clearSessionCookie(w, r)
		http.Error(w, "Invalid session", http.StatusUnauthorized)
		return
	}

	// Best-effort refresh: on failure the cookie is simply not re-issued.
	// NOTE(review): expires_in still reports the full TTL in that case.
	if err := h.sessionManager.RefreshSession(r.Context(), cookie.Value); err == nil {
		http.SetCookie(w, h.newSessionCookie(r, cookie.Value))
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"user":       sessionData,
		"expires_in": int(h.sessionManager.TTL().Seconds()),
	})
}

// Me returns the currently authenticated user profile.
// It resolves the session named by the session cookie, loads the user's
// profile, extends the session when possible (re-issuing the cookie), and
// responds with {"user": ..., "expires_in": <TTL seconds>}. Every failure
// yields 401; an unknown or expired session also clears the session cookie.
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
	sessionCookie, err := r.Cookie(sessionCookieName)
	if err != nil {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	sessionData, err := h.sessionManager.GetSession(r.Context(), sessionCookie.Value)
	if err != nil {
		h.clearSessionCookie(w, r)
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	// NOTE(review): any profile-lookup error (including a transient store
	// failure) is reported as 401 while the session itself is left intact.
	user, err := h.authService.GetUserProfile(r.Context(), sessionData.UserID)
	if err != nil {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	// Best-effort sliding expiry: on failure the cookie is simply not re-issued.
	if err := h.sessionManager.RefreshSession(r.Context(), sessionCookie.Value); err == nil {
		http.SetCookie(w, h.newSessionCookie(r, sessionCookie.Value))
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]interface{}{
		"user":       user,
		"expires_in": int(h.sessionManager.TTL().Seconds()),
	})
}

// Health is a liveness endpoint; it always responds with
// {"status": "healthy"}.
func (h *AuthHandler) Health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{
		"status": "healthy",
	})
}

// isSecureRequest reports whether the client reached us over HTTPS. That is
// true either directly (r.TLS is set) or via a reverse proxy that sets
// X-Forwarded-Proto. When several proxies append to that header it carries a
// comma-separated list such as "https, http". The first entry is the scheme
// the client used, so only that entry is compared.
//
// NOTE(review): X-Forwarded-Proto is client-controllable unless a trusted
// proxy overwrites it.
func isSecureRequest(r *http.Request) bool {
	if r.TLS != nil {
		return true
	}
	proto, _, _ := strings.Cut(r.Header.Get("X-Forwarded-Proto"), ",")
	return strings.EqualFold(strings.TrimSpace(proto), "https")
}

// buildBackendURL returns an absolute URL for path on this backend. The
// scheme comes from isSecureRequest and the host comes from the request's
// Host header.
func buildBackendURL(r *http.Request, path string) string {
	scheme := "http"
	if isSecureRequest(r) {
		scheme = "https"
	}
	return scheme + "://" + r.Host + path
}

// buildFrontendLoginURL returns the frontend's /login URL with optional
// "status" and "message" query parameters (empty values are omitted).
// The frontend base comes from the FRONTEND_URL environment variable,
// defaulting to http://localhost:5173. If the base cannot be parsed, the bare
// "<base>/login" is returned without query parameters.
func buildFrontendLoginURL(status, message string) string {
	frontendURL := os.Getenv("FRONTEND_URL")
	if frontendURL == "" {
		frontendURL = "http://localhost:5173"
	}

	parsed, err := url.Parse(strings.TrimRight(frontendURL, "/") + "/login")
	if err != nil {
		return frontendURL + "/login"
	}

	query := parsed.Query()
	if status != "" {
		query.Set("status", status)
	}
	if message != "" {
		query.Set("message", message)
	}
	parsed.RawQuery = query.Encode()
	return parsed.String()
}

// setSessionCookie creates a server-side session for user and sets the session
// cookie on w. A nil user is a no-op that creates no session and returns nil.
// Errors from the session manager are returned unchanged.
func (h *AuthHandler) setSessionCookie(w http.ResponseWriter, r *http.Request, user *dto.UserDTO) error {
	if user == nil {
		return nil
	}

	sessionID, err := h.sessionManager.CreateSession(r.Context(), &auth.SessionData{
		UserID:   user.ID,
		Email:    user.Email,
		Username: user.Username,
	})
	if err != nil {
		return err
	}

	http.SetCookie(w, h.newSessionCookie(r, sessionID))
	return nil
}

// newSessionCookie builds the HttpOnly, SameSite=Lax session cookie for
// sessionID. Its Max-Age matches the session manager's TTL, and it is Secure
// when the request arrived over HTTPS.
func (h *AuthHandler) newSessionCookie(r *http.Request, sessionID string) *http.Cookie {
	return &http.Cookie{
		Name:     sessionCookieName,
		Value:    sessionID,
		Path:     "/",
		MaxAge:   int(h.sessionManager.TTL().Seconds()),
		HttpOnly: true,
		Secure:   isSecureRequest(r),
		SameSite: http.SameSiteLaxMode,
	}
}

// clearSessionCookie tells the browser to delete the session cookie. It uses
// an empty value with MaxAge -1, which net/http emits as "Max-Age=0". The
// cookie attributes match newSessionCookie.
func (h *AuthHandler) clearSessionCookie(w http.ResponseWriter, r *http.Request) {
	http.SetCookie(w, &http.Cookie{
		Name:     sessionCookieName,
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   isSecureRequest(r),
		SameSite: http.SameSiteLaxMode,
	})
}