first commit
This commit is contained in:
299
backend/internal/interfaces/handlers/auth_handler.go
Normal file
299
backend/internal/interfaces/handlers/auth_handler.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/noteapp/backend/internal/application/dto"
|
||||
"github.com/noteapp/backend/internal/application/services"
|
||||
"github.com/noteapp/backend/internal/infrastructure/auth"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication endpoints
|
||||
type AuthHandler struct {
|
||||
authService *services.AuthService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new auth handler
|
||||
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register handles user registration
|
||||
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.RegisterRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Basic validation
|
||||
if req.Email == "" || req.Password == "" || req.Username == "" {
|
||||
http.Error(w, "Missing required fields", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.authService.Register(r.Context(), &req)
|
||||
if err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "registration is currently disabled") {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// Login handles user login
|
||||
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.authService.Login(r.Context(), &req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set secure HTTP-only cookie for refresh token
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: response.RefreshToken,
|
||||
Path: "/",
|
||||
MaxAge: 7 * 24 * 60 * 60, // 7 days
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// Logout handles user logout
|
||||
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
// Clear refresh token cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
})
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
|
||||
}
|
||||
|
||||
// ListProviders returns all active OAuth/OIDC providers.
|
||||
func (h *AuthHandler) ListProviders(w http.ResponseWriter, r *http.Request) {
|
||||
providers, err := h.authService.ListProviders(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{"providers": providers})
|
||||
}
|
||||
|
||||
// CreateProvider stores a new OAuth/OIDC provider configuration.
|
||||
func (h *AuthHandler) CreateProvider(w http.ResponseWriter, r *http.Request) {
|
||||
var req dto.CreateAuthProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.authService.CreateProvider(r.Context(), &req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(provider)
|
||||
}
|
||||
|
||||
// StartProviderLogin redirects the browser to the selected provider.
|
||||
func (h *AuthHandler) StartProviderLogin(w http.ResponseWriter, r *http.Request) {
|
||||
providerID, err := bson.ObjectIDFromHex(mux.Vars(r)["providerId"])
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid provider ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
state, err := auth.GenerateStateToken()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to create OAuth state", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state,
|
||||
Path: "/",
|
||||
MaxAge: 10 * 60,
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
redirectURI := buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback")
|
||||
authorizationURL, err := h.authService.BuildProviderAuthorizationURL(r.Context(), providerID, redirectURI, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, authorizationURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// CompleteProviderLogin exchanges the authorization code and redirects back to the frontend.
|
||||
func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Request) {
|
||||
providerID, err := bson.ObjectIDFromHex(mux.Vars(r)["providerId"])
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid provider ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
stateCookie, err := r.Cookie("oauth_state")
|
||||
if err != nil || stateCookie.Value == "" || stateCookie.Value != r.URL.Query().Get("state") {
|
||||
http.Error(w, "Invalid OAuth state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.authService.CompleteProviderLogin(r.Context(), providerID, r.URL.Query().Get("code"), buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback"))
|
||||
if err != nil {
|
||||
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error(), "", nil), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: response.RefreshToken,
|
||||
Path: "/",
|
||||
MaxAge: 7 * 24 * 60 * 60,
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, buildFrontendLoginURL("oauth_success", "", response.AccessToken, response.User), http.StatusFound)
|
||||
}
|
||||
|
||||
// RefreshToken handles token refresh
|
||||
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Get refresh token from cookie
|
||||
cookie, err := r.Cookie("refresh_token")
|
||||
if err != nil {
|
||||
http.Error(w, "Refresh token not found", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, err := h.authService.RefreshAccessToken(r.Context(), cookie.Value)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}
|
||||
|
||||
// Health check endpoint
|
||||
func (h *AuthHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "healthy",
|
||||
})
|
||||
}
|
||||
|
||||
func isSecureRequest(r *http.Request) bool {
|
||||
if r.TLS != nil {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(r.Header.Get("X-Forwarded-Proto"), "https")
|
||||
}
|
||||
|
||||
func buildBackendURL(r *http.Request, path string) string {
|
||||
scheme := "http"
|
||||
if isSecureRequest(r) {
|
||||
scheme = "https"
|
||||
}
|
||||
return scheme + "://" + r.Host + path
|
||||
}
|
||||
|
||||
func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDTO) string {
|
||||
frontendURL := os.Getenv("FRONTEND_URL")
|
||||
if frontendURL == "" {
|
||||
frontendURL = "http://localhost:5173"
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(strings.TrimRight(frontendURL, "/") + "/login")
|
||||
if err != nil {
|
||||
return frontendURL + "/login"
|
||||
}
|
||||
|
||||
query := parsed.Query()
|
||||
if status != "" {
|
||||
query.Set("status", status)
|
||||
}
|
||||
if message != "" {
|
||||
query.Set("message", message)
|
||||
}
|
||||
if accessToken != "" {
|
||||
query.Set("access_token", accessToken)
|
||||
}
|
||||
if user != nil {
|
||||
payload, _ := json.Marshal(user)
|
||||
query.Set("user_json", string(payload))
|
||||
query.Set("user", base64.RawURLEncoding.EncodeToString(payload))
|
||||
}
|
||||
parsed.RawQuery = query.Encode()
|
||||
return parsed.String()
|
||||
}
|
||||
Reference in New Issue
Block a user