Files
notely/backend/internal/interfaces/handlers/auth_handler.go
domrichardson df40cc57e1 first commit
2026-03-24 16:03:04 +00:00

300 lines
8.2 KiB
Go

package handlers
import (
"encoding/base64"
"encoding/json"
"net/http"
"net/url"
"os"
"strings"
"github.com/gorilla/mux"
"github.com/noteapp/backend/internal/application/dto"
"github.com/noteapp/backend/internal/application/services"
"github.com/noteapp/backend/internal/infrastructure/auth"
"go.mongodb.org/mongo-driver/v2/bson"
)
// AuthHandler groups the HTTP handlers for the authentication endpoints
// (registration, login, logout, token refresh and OAuth/OIDC provider
// flows). The handlers delegate to the wrapped AuthService.
type AuthHandler struct {
	authService *services.AuthService
}
// NewAuthHandler returns an AuthHandler that serves requests using the
// supplied auth service.
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
	handler := new(AuthHandler)
	handler.authService = authService
	return handler
}
// Register creates a new user account from a JSON request body.
//
// Only POST is accepted (405 otherwise). A body that cannot be decoded, or
// one missing the email, password or username, yields 400. A service error
// whose text mentions that registration is currently disabled yields 403;
// every other service error yields 400 with the error text as the body.
// On success the service response is written as JSON.
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	payload := new(dto.RegisterRequest)
	if decodeErr := json.NewDecoder(r.Body).Decode(payload); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// All three identity fields must be present before calling the service.
	complete := payload.Email != "" && payload.Password != "" && payload.Username != ""
	if !complete {
		http.Error(w, "Missing required fields", http.StatusBadRequest)
		return
	}

	result, err := h.authService.Register(r.Context(), payload)
	if err != nil {
		// The service signals "registration disabled" only through its
		// error text, so the status code is picked by matching on it.
		status := http.StatusBadRequest
		if strings.Contains(strings.ToLower(err.Error()), "registration is currently disabled") {
			status = http.StatusForbidden
		}
		http.Error(w, err.Error(), status)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	encoder := json.NewEncoder(w)
	encoder.Encode(result)
}
// Login authenticates a user from a JSON request body.
//
// Only POST is accepted (405 otherwise). An undecodable body yields 400 and
// any service error yields 401 with the error text. On success the refresh
// token is stored in an HTTP-only cookie valid for seven days and the
// service response is written as JSON.
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	credentials := new(dto.LoginRequest)
	if decodeErr := json.NewDecoder(r.Body).Decode(credentials); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	session, err := h.authService.Login(r.Context(), credentials)
	if err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	// The refresh token travels in an HTTP-only cookie so page scripts
	// cannot read it; Secure is only set when the request arrived over HTTPS.
	const sevenDaysInSeconds = 7 * 24 * 60 * 60
	refreshCookie := http.Cookie{
		Name:     "refresh_token",
		Value:    session.RefreshToken,
		Path:     "/",
		MaxAge:   sevenDaysInSeconds,
		HttpOnly: true,
		Secure:   isSecureRequest(r),
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &refreshCookie)

	w.Header().Set("Content-Type", "application/json")
	encoder := json.NewEncoder(w)
	encoder.Encode(session)
}
// Logout expires the refresh-token cookie in the browser and replies with a
// JSON confirmation message. It does not call the auth service.
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
	// A negative MaxAge tells the browser to delete the cookie immediately.
	expired := &http.Cookie{
		Name:     "refresh_token",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   isSecureRequest(r),
	}
	http.SetCookie(w, expired)

	body := map[string]string{"message": "Logged out successfully"}
	w.Header().Set("Content-Type", "application/json")
	encoder := json.NewEncoder(w)
	encoder.Encode(body)
}
// ListProviders writes the providers returned by the auth service as a JSON
// object of the form {"providers": [...]}. A service error yields 500 with
// the error text as the body.
func (h *AuthHandler) ListProviders(w http.ResponseWriter, r *http.Request) {
	available, err := h.authService.ListProviders(r.Context())
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	body := map[string]interface{}{"providers": available}
	w.Header().Set("Content-Type", "application/json")
	encoder := json.NewEncoder(w)
	encoder.Encode(body)
}
// CreateProvider decodes an OAuth/OIDC provider configuration from the JSON
// request body and hands it to the auth service. An undecodable body or a
// service error yields 400; on success the created provider is returned as
// JSON with status 201.
func (h *AuthHandler) CreateProvider(w http.ResponseWriter, r *http.Request) {
	config := new(dto.CreateAuthProviderRequest)
	if decodeErr := json.NewDecoder(r.Body).Decode(config); decodeErr != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	created, err := h.authService.CreateProvider(r.Context(), config)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// The header must be set before WriteHeader, which commits the headers.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	encoder := json.NewEncoder(w)
	encoder.Encode(created)
}
// StartProviderLogin begins an OAuth/OIDC login with the provider named by
// the "providerId" route variable.
//
// It generates a state token, stores it in a ten-minute HTTP-only
// "oauth_state" cookie, and redirects (302) to the authorization URL built
// by the auth service. An invalid provider ID or a failure to build the
// authorization URL yields 400; a failure to generate the state yields 500.
func (h *AuthHandler) StartProviderLogin(w http.ResponseWriter, r *http.Request) {
	rawID := mux.Vars(r)["providerId"]
	providerID, err := bson.ObjectIDFromHex(rawID)
	if err != nil {
		http.Error(w, "Invalid provider ID", http.StatusBadRequest)
		return
	}

	state, err := auth.GenerateStateToken()
	if err != nil {
		http.Error(w, "Failed to create OAuth state", http.StatusInternalServerError)
		return
	}

	// The callback handler compares this cookie with the "state" query
	// parameter it receives.
	const tenMinutesInSeconds = 10 * 60
	stateCookie := http.Cookie{
		Name:     "oauth_state",
		Value:    state,
		Path:     "/",
		MaxAge:   tenMinutesInSeconds,
		HttpOnly: true,
		Secure:   isSecureRequest(r),
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, &stateCookie)

	callbackPath := "/api/v1/auth/providers/" + providerID.Hex() + "/callback"
	callbackURL := buildBackendURL(r, callbackPath)

	target, err := h.authService.BuildProviderAuthorizationURL(r.Context(), providerID, callbackURL, state)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	http.Redirect(w, r, target, http.StatusFound)
}
// CompleteProviderLogin handles the OAuth/OIDC callback for the provider
// named by the "providerId" route variable.
//
// The "state" query parameter must match the non-empty "oauth_state" cookie;
// otherwise, or when the provider ID is invalid, the response is 400. Once
// the state has been validated its cookie is expired immediately, so the
// same state cannot be replayed whether the code exchange succeeds or fails.
//
// The "code" query parameter is then passed to the auth service. On failure
// the browser is redirected (302) to the frontend login page with an
// "oauth_error" status and the error text. On success the refresh token is
// stored in a seven-day HTTP-only cookie and the browser is redirected (302)
// to the frontend login page with an "oauth_success" status, the access
// token and the user.
func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Request) {
	providerID, err := bson.ObjectIDFromHex(mux.Vars(r)["providerId"])
	if err != nil {
		http.Error(w, "Invalid provider ID", http.StatusBadRequest)
		return
	}

	query := r.URL.Query()
	stateCookie, err := r.Cookie("oauth_state")
	if err != nil || stateCookie.Value == "" || stateCookie.Value != query.Get("state") {
		http.Error(w, "Invalid OAuth state", http.StatusBadRequest)
		return
	}

	// The state is single-use: expire it before the exchange so that the
	// error redirect below also carries the deletion.
	http.SetCookie(w, &http.Cookie{
		Name:     "oauth_state",
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   isSecureRequest(r),
		SameSite: http.SameSiteLaxMode,
	})

	// Must be the same callback URL that StartProviderLogin sent to the provider.
	redirectURI := buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback")
	response, err := h.authService.CompleteProviderLogin(r.Context(), providerID, query.Get("code"), redirectURI)
	if err != nil {
		http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error(), "", nil), http.StatusFound)
		return
	}

	http.SetCookie(w, &http.Cookie{
		Name:     "refresh_token",
		Value:    response.RefreshToken,
		Path:     "/",
		MaxAge:   7 * 24 * 60 * 60, // 7 days
		HttpOnly: true,
		Secure:   isSecureRequest(r),
		SameSite: http.SameSiteLaxMode,
	})

	http.Redirect(w, r, buildFrontendLoginURL("oauth_success", "", response.AccessToken, response.User), http.StatusFound)
}
// RefreshToken issues a new access token from the "refresh_token" cookie.
//
// Only POST is accepted (405 otherwise). A missing cookie or a token the
// auth service rejects yields 401 with a generic message. On success the
// response is a JSON object with "access_token" and an "expires_in" of 3600.
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	refreshCookie, err := r.Cookie("refresh_token")
	if err != nil {
		http.Error(w, "Refresh token not found", http.StatusUnauthorized)
		return
	}

	token, err := h.authService.RefreshAccessToken(r.Context(), refreshCookie.Value)
	if err != nil {
		// The service error is deliberately not echoed to the client.
		http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
		return
	}

	body := map[string]interface{}{
		"access_token": token,
		"expires_in":   3600,
	}
	w.Header().Set("Content-Type", "application/json")
	encoder := json.NewEncoder(w)
	encoder.Encode(body)
}
// Health replies with the fixed JSON document {"status": "healthy"}. It
// performs no checks of its own.
func (h *AuthHandler) Health(w http.ResponseWriter, r *http.Request) {
	body := map[string]string{"status": "healthy"}
	w.Header().Set("Content-Type", "application/json")
	encoder := json.NewEncoder(w)
	encoder.Encode(body)
}
// isSecureRequest reports whether the request reached the client-facing edge
// over HTTPS: either the connection itself is TLS, or the first entry of the
// X-Forwarded-Proto header is "https" (case-insensitive).
//
// Chained proxies may append to the header, producing a comma-separated list
// such as "https, http"; only the first entry is compared, with surrounding
// whitespace ignored.
//
// NOTE(review): the header is taken at face value — this assumes a reverse
// proxy in front of the service sets or overwrites it; confirm for the
// deployment.
func isSecureRequest(r *http.Request) bool {
	if r.TLS != nil {
		return true
	}
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-Proto"), ",")
	return strings.EqualFold(strings.TrimSpace(first), "https")
}
// buildBackendURL returns an absolute URL for path on this backend, using
// the request's Host and "https" when isSecureRequest reports true, "http"
// otherwise. path is appended as-is.
func buildBackendURL(r *http.Request, path string) string {
	if isSecureRequest(r) {
		return "https://" + r.Host + path
	}
	return "http://" + r.Host + path
}
// buildFrontendLoginURL returns the frontend "/login" URL with the outcome
// of a provider login encoded in its query string.
//
// The base comes from the FRONTEND_URL environment variable, falling back to
// "http://localhost:5173"; trailing slashes are trimmed before "/login" is
// appended. Non-empty status, message and accessToken are added as the
// "status", "message" and "access_token" parameters. A non-nil user is
// added as "user_json" (raw JSON) and "user" (the same JSON, unpadded
// base64url); if the user cannot be marshalled both parameters are omitted
// rather than sent empty.
//
// If the base cannot be parsed as a URL, the plain login URL is returned
// without any query parameters.
//
// NOTE(review): the access token is placed in the URL, where it can end up
// in browser history and server logs; the frontend contract depends on it,
// so it is kept as-is here.
func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDTO) string {
	frontendURL := os.Getenv("FRONTEND_URL")
	if frontendURL == "" {
		frontendURL = "http://localhost:5173"
	}

	loginURL := strings.TrimRight(frontendURL, "/") + "/login"
	parsed, err := url.Parse(loginURL)
	if err != nil {
		return loginURL
	}

	query := parsed.Query()
	if status != "" {
		query.Set("status", status)
	}
	if message != "" {
		query.Set("message", message)
	}
	if accessToken != "" {
		query.Set("access_token", accessToken)
	}
	if user != nil {
		// Skip the user parameters on a marshalling failure instead of
		// emitting empty values the frontend would then try to decode.
		if payload, err := json.Marshal(user); err == nil {
			query.Set("user_json", string(payload))
			query.Set("user", base64.RawURLEncoding.EncodeToString(payload))
		}
	}

	parsed.RawQuery = query.Encode()
	return parsed.String()
}