398 lines
11 KiB
Go
398 lines
11 KiB
Go
package handlers
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
|
|
"gitea.hostxtra.co.uk/mrhid6/notely/backend/internal/application/dto"
|
|
"gitea.hostxtra.co.uk/mrhid6/notely/backend/internal/application/services"
|
|
"gitea.hostxtra.co.uk/mrhid6/notely/backend/internal/infrastructure/auth"
|
|
"github.com/gorilla/mux"
|
|
"go.mongodb.org/mongo-driver/v2/bson"
|
|
)
|
|
|
|
// AuthHandler handles authentication endpoints
|
|
type AuthHandler struct {
|
|
authService *services.AuthService
|
|
sessionManager *auth.SessionManager
|
|
}
|
|
|
|
// NewAuthHandler creates a new auth handler
|
|
func NewAuthHandler(authService *services.AuthService, sessionManager *auth.SessionManager) *AuthHandler {
|
|
return &AuthHandler{
|
|
authService: authService,
|
|
sessionManager: sessionManager,
|
|
}
|
|
}
|
|
|
|
// sessionCookieName is the name of the cookie that carries the session ID.
const sessionCookieName = "session_id"
|
|
|
|
// Register handles user registration
|
|
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
var req dto.RegisterRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Basic validation
|
|
if req.Email == "" || req.Password == "" || req.Username == "" {
|
|
http.Error(w, "Missing required fields", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
response, err := h.authService.Register(r.Context(), &req)
|
|
if err != nil {
|
|
if strings.Contains(strings.ToLower(err.Error()), "registration is currently disabled") {
|
|
http.Error(w, err.Error(), http.StatusForbidden)
|
|
return
|
|
}
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if err := h.setSessionCookie(w, r, response.User); err != nil {
|
|
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
// Login handles user login
|
|
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
var req dto.LoginRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
response, err := h.authService.Login(r.Context(), &req)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if err := h.setSessionCookie(w, r, response.User); err != nil {
|
|
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}
|
|
|
|
// Logout handles user logout
|
|
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
|
sessionCookie, err := r.Cookie(sessionCookieName)
|
|
if err == nil {
|
|
_ = h.sessionManager.DeleteSession(r.Context(), sessionCookie.Value)
|
|
}
|
|
|
|
h.clearSessionCookie(w, r)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
|
|
}
|
|
|
|
// ListProviders returns all active OAuth/OIDC providers.
|
|
func (h *AuthHandler) ListProviders(w http.ResponseWriter, r *http.Request) {
|
|
providers, err := h.authService.ListProviders(r.Context())
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]interface{}{"providers": providers})
|
|
}
|
|
|
|
// ListProvidersForAdmin returns all OAuth/OIDC providers, including inactive ones.
|
|
func (h *AuthHandler) ListProvidersForAdmin(w http.ResponseWriter, r *http.Request) {
|
|
providers, err := h.authService.ListProvidersForAdmin(r.Context())
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]interface{}{"providers": providers})
|
|
}
|
|
|
|
// CreateProvider stores a new OAuth/OIDC provider configuration.
|
|
func (h *AuthHandler) CreateProvider(w http.ResponseWriter, r *http.Request) {
|
|
var req dto.CreateAuthProviderRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
provider, err := h.authService.CreateProvider(r.Context(), &req)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusCreated)
|
|
json.NewEncoder(w).Encode(provider)
|
|
}
|
|
|
|
// UpdateProvider updates an existing OAuth/OIDC provider configuration.
|
|
func (h *AuthHandler) UpdateProvider(w http.ResponseWriter, r *http.Request) {
|
|
providerID, err := bson.ObjectIDFromHex(mux.Vars(r)["providerId"])
|
|
if err != nil {
|
|
http.Error(w, "Invalid provider ID", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var req dto.UpdateAuthProviderRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
provider, err := h.authService.UpdateProvider(r.Context(), providerID, &req)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(provider)
|
|
}
|
|
|
|
// StartProviderLogin redirects the browser to the selected provider.
|
|
func (h *AuthHandler) StartProviderLogin(w http.ResponseWriter, r *http.Request) {
|
|
providerID, err := bson.ObjectIDFromHex(mux.Vars(r)["providerId"])
|
|
if err != nil {
|
|
http.Error(w, "Invalid provider ID", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
state, err := auth.GenerateStateToken()
|
|
if err != nil {
|
|
http.Error(w, "Failed to create OAuth state", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "oauth_state",
|
|
Value: state,
|
|
Path: "/",
|
|
MaxAge: 10 * 60,
|
|
HttpOnly: true,
|
|
Secure: isSecureRequest(r),
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
redirectURI := buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback")
|
|
authorizationURL, err := h.authService.BuildProviderAuthorizationURL(r.Context(), providerID, redirectURI, state)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, authorizationURL, http.StatusFound)
|
|
}
|
|
|
|
// CompleteProviderLogin exchanges the authorization code and redirects back to the frontend.
|
|
func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Request) {
|
|
providerID, err := bson.ObjectIDFromHex(mux.Vars(r)["providerId"])
|
|
if err != nil {
|
|
http.Error(w, "Invalid provider ID", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
stateCookie, err := r.Cookie("oauth_state")
|
|
if err != nil || stateCookie.Value == "" || stateCookie.Value != r.URL.Query().Get("state") {
|
|
http.Error(w, "Invalid OAuth state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
response, err := h.authService.CompleteProviderLogin(r.Context(), providerID, r.URL.Query().Get("code"), buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback"))
|
|
if err != nil {
|
|
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error()), http.StatusFound)
|
|
return
|
|
}
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "oauth_state",
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
HttpOnly: true,
|
|
Secure: isSecureRequest(r),
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
|
|
if err := h.setSessionCookie(w, r, response.User); err != nil {
|
|
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", "Failed to create session"), http.StatusFound)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, buildFrontendLoginURL("oauth_success", ""), http.StatusFound)
|
|
}
|
|
|
|
// RefreshToken handles token refresh
|
|
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
cookie, err := r.Cookie(sessionCookieName)
|
|
if err != nil {
|
|
http.Error(w, "Session not found", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
sessionData, err := h.sessionManager.GetSession(r.Context(), cookie.Value)
|
|
if err != nil {
|
|
http.Error(w, "Invalid session", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if err := h.sessionManager.RefreshSession(r.Context(), cookie.Value); err == nil {
|
|
http.SetCookie(w, h.newSessionCookie(r, cookie.Value))
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"user": sessionData,
|
|
"expires_in": int(h.sessionManager.TTL().Seconds()),
|
|
})
|
|
}
|
|
|
|
// Me returns the currently authenticated user profile.
|
|
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
|
|
sessionCookie, err := r.Cookie(sessionCookieName)
|
|
if err != nil {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
sessionData, err := h.sessionManager.GetSession(r.Context(), sessionCookie.Value)
|
|
if err != nil {
|
|
h.clearSessionCookie(w, r)
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
user, err := h.authService.GetUserProfile(r.Context(), sessionData.UserID)
|
|
if err != nil {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if err := h.sessionManager.RefreshSession(r.Context(), sessionCookie.Value); err == nil {
|
|
http.SetCookie(w, h.newSessionCookie(r, sessionCookie.Value))
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
|
"user": user,
|
|
"expires_in": int(h.sessionManager.TTL().Seconds()),
|
|
})
|
|
}
|
|
|
|
// Health check endpoint
|
|
func (h *AuthHandler) Health(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"status": "healthy",
|
|
})
|
|
}
|
|
|
|
// isSecureRequest reports whether the request reached the client-facing edge
// over HTTPS.
//
// A request that arrived over TLS is always secure. Otherwise the
// X-Forwarded-Proto header is consulted. Chained proxies may append to the
// header and produce a comma-separated list ("https, http"), so only the
// first entry is used, trimmed and compared case-insensitively with "https".
//
// NOTE(review): the header is client-controlled unless a trusted proxy
// overwrites it — confirm the deployment strips it at the edge.
func isSecureRequest(r *http.Request) bool {
	if r.TLS != nil {
		return true
	}

	proto := r.Header.Get("X-Forwarded-Proto")
	if comma := strings.IndexByte(proto, ','); comma >= 0 {
		proto = proto[:comma]
	}
	return strings.EqualFold(strings.TrimSpace(proto), "https")
}
|
|
|
|
func buildBackendURL(r *http.Request, path string) string {
|
|
scheme := "http"
|
|
if isSecureRequest(r) {
|
|
scheme = "https"
|
|
}
|
|
return scheme + "://" + r.Host + path
|
|
}
|
|
|
|
// buildFrontendLoginURL returns the frontend login page URL with optional
// "status" and "message" query parameters.
//
// The base comes from the FRONTEND_URL environment variable and defaults to
// "http://localhost:5173" when it is unset or empty. Trailing slashes on the
// base are removed before "/login" is appended. Empty status or message
// values are left out of the query. If the resulting URL cannot be parsed,
// the login URL is returned without query parameters.
func buildFrontendLoginURL(status, message string) string {
	base := os.Getenv("FRONTEND_URL")
	if base == "" {
		base = "http://localhost:5173"
	}
	loginURL := strings.TrimRight(base, "/") + "/login"

	parsed, err := url.Parse(loginURL)
	if err != nil {
		// Return the trimmed form so a trailing slash in FRONTEND_URL never
		// produces "//login".
		return loginURL
	}

	query := parsed.Query()
	if status != "" {
		query.Set("status", status)
	}
	if message != "" {
		query.Set("message", message)
	}
	parsed.RawQuery = query.Encode()
	return parsed.String()
}
|
|
|
|
func (h *AuthHandler) setSessionCookie(w http.ResponseWriter, r *http.Request, user *dto.UserDTO) error {
|
|
if user == nil {
|
|
return nil
|
|
}
|
|
|
|
sessionID, err := h.sessionManager.CreateSession(r.Context(), &auth.SessionData{
|
|
UserID: user.ID,
|
|
Email: user.Email,
|
|
Username: user.Username,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
http.SetCookie(w, h.newSessionCookie(r, sessionID))
|
|
return nil
|
|
}
|
|
|
|
func (h *AuthHandler) newSessionCookie(r *http.Request, sessionID string) *http.Cookie {
|
|
return &http.Cookie{
|
|
Name: sessionCookieName,
|
|
Value: sessionID,
|
|
Path: "/",
|
|
MaxAge: int(h.sessionManager.TTL().Seconds()),
|
|
HttpOnly: true,
|
|
Secure: isSecureRequest(r),
|
|
SameSite: http.SameSiteLaxMode,
|
|
}
|
|
}
|
|
|
|
func (h *AuthHandler) clearSessionCookie(w http.ResponseWriter, r *http.Request) {
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: sessionCookieName,
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
HttpOnly: true,
|
|
Secure: isSecureRequest(r),
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
}
|