fix: fixes to session storage
All checks were successful
Build and Push App Image / build-and-push (push) Successful in 1m27s
All checks were successful
Build and Push App Image / build-and-push (push) Successful in 1m27s
This commit is contained in:
@@ -114,22 +114,9 @@ func (s *AuthService) Register(ctx context.Context, req *dto.RegisterRequest) (*
|
||||
}
|
||||
}
|
||||
|
||||
// Generate tokens
|
||||
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.LoginResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
User: dto.NewUserDTO(user),
|
||||
ExpiresIn: 3600, // 1 hour
|
||||
User: dto.NewUserDTO(user),
|
||||
ExpiresIn: 3600, // 1 hour
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -165,27 +152,18 @@ func (s *AuthService) Login(ctx context.Context, req *dto.LoginRequest) (*dto.Lo
|
||||
// Log error but don't fail the login
|
||||
}
|
||||
|
||||
// Generate tokens
|
||||
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.LoginResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
User: dto.NewUserDTO(user),
|
||||
ExpiresIn: 3600,
|
||||
User: dto.NewUserDTO(user),
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshAccessToken refreshes an access token
|
||||
func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken string) (string, error) {
|
||||
if s.jwtManager == nil {
|
||||
return "", errors.New("jwt refresh is unavailable")
|
||||
}
|
||||
|
||||
claims, err := s.jwtManager.VerifyRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -199,6 +177,27 @@ func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken strin
|
||||
return s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
|
||||
}
|
||||
|
||||
// GetUserProfile returns profile DTO for the provided user ID.
|
||||
func (s *AuthService) GetUserProfile(ctx context.Context, userID string) (*dto.UserDTO, error) {
|
||||
objID, err := bson.ObjectIDFromHex(strings.TrimSpace(userID))
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid user id")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetUserByID(ctx, objID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.permissionService != nil {
|
||||
if err := s.permissionService.UpdateUserEffectivePermissions(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return dto.NewUserDTO(user), nil
|
||||
}
|
||||
|
||||
// RequestPasswordReset initiates password reset flow
|
||||
func (s *AuthService) RequestPasswordReset(ctx context.Context, email string) error {
|
||||
user, err := s.userRepo.GetUserByEmail(ctx, email)
|
||||
@@ -444,17 +443,7 @@ func (s *AuthService) CompleteProviderLogin(ctx context.Context, providerID bson
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.LoginResponse{AccessToken: accessToken, RefreshToken: refreshToken, User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil
|
||||
return &dto.LoginResponse{User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil
|
||||
}
|
||||
|
||||
type providerProfile struct {
|
||||
|
||||
114
backend/internal/infrastructure/auth/session.go
Normal file
114
backend/internal/infrastructure/auth/session.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SessionData stores authenticated identity data in Redis.
|
||||
type SessionData struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
// SessionManager handles Redis-backed session lifecycle operations.
|
||||
type SessionManager struct {
|
||||
redis *redis.Client
|
||||
ttl time.Duration
|
||||
prefix string
|
||||
}
|
||||
|
||||
func NewSessionManager(redisClient *redis.Client, ttl time.Duration) *SessionManager {
|
||||
if ttl <= 0 {
|
||||
ttl = 7 * 24 * time.Hour
|
||||
}
|
||||
|
||||
return &SessionManager{
|
||||
redis: redisClient,
|
||||
ttl: ttl,
|
||||
prefix: "session:",
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) TTL() time.Duration {
|
||||
return m.ttl
|
||||
}
|
||||
|
||||
func (m *SessionManager) CreateSession(ctx context.Context, data *SessionData) (string, error) {
|
||||
if data == nil {
|
||||
return "", errors.New("session data is required")
|
||||
}
|
||||
|
||||
sessionID, err := GenerateRandomToken(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := m.redis.Set(ctx, m.key(sessionID), payload, m.ttl).Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) GetSession(ctx context.Context, sessionID string) (*SessionData, error) {
|
||||
if sessionID == "" {
|
||||
return nil, errors.New("session id is required")
|
||||
}
|
||||
|
||||
payload, err := m.redis.Get(ctx, m.key(sessionID)).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data SessionData
|
||||
if err := json.Unmarshal([]byte(payload), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) RefreshSession(ctx context.Context, sessionID string) error {
|
||||
if sessionID == "" {
|
||||
return errors.New("session id is required")
|
||||
}
|
||||
|
||||
if err := m.redis.Expire(ctx, m.key(sessionID), m.ttl).Err(); err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return errors.New("session not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
if sessionID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.redis.Del(ctx, m.key(sessionID)).Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) key(sessionID string) string {
|
||||
return m.prefix + sessionID
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -17,16 +16,20 @@ import (
|
||||
|
||||
// AuthHandler handles authentication endpoints
|
||||
type AuthHandler struct {
|
||||
authService *services.AuthService
|
||||
authService *services.AuthService
|
||||
sessionManager *auth.SessionManager
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new auth handler
|
||||
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
|
||||
func NewAuthHandler(authService *services.AuthService, sessionManager *auth.SessionManager) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
authService: authService,
|
||||
sessionManager: sessionManager,
|
||||
}
|
||||
}
|
||||
|
||||
const sessionCookieName = "session_id"
|
||||
|
||||
// Register handles user registration
|
||||
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -56,6 +59,11 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.setSessionCookie(w, r, response.User); err != nil {
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
@@ -79,16 +87,10 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set secure HTTP-only cookie for refresh token
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: response.RefreshToken,
|
||||
Path: "/",
|
||||
MaxAge: 7 * 24 * 60 * 60, // 7 days
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
if err := h.setSessionCookie(w, r, response.User); err != nil {
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
@@ -96,15 +98,12 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Logout handles user logout
|
||||
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
// Clear refresh token cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
})
|
||||
sessionCookie, err := r.Cookie(sessionCookieName)
|
||||
if err == nil {
|
||||
_ = h.sessionManager.DeleteSession(r.Context(), sessionCookie.Value)
|
||||
}
|
||||
|
||||
h.clearSessionCookie(w, r)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
|
||||
@@ -215,7 +214,7 @@ func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
response, err := h.authService.CompleteProviderLogin(r.Context(), providerID, r.URL.Query().Get("code"), buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback"))
|
||||
if err != nil {
|
||||
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error(), "", nil), http.StatusFound)
|
||||
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error()), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -229,17 +228,12 @@ func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Reque
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refresh_token",
|
||||
Value: response.RefreshToken,
|
||||
Path: "/",
|
||||
MaxAge: 7 * 24 * 60 * 60,
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
if err := h.setSessionCookie(w, r, response.User); err != nil {
|
||||
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", "Failed to create session"), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, buildFrontendLoginURL("oauth_success", "", response.AccessToken, response.User), http.StatusFound)
|
||||
http.Redirect(w, r, buildFrontendLoginURL("oauth_success", ""), http.StatusFound)
|
||||
}
|
||||
|
||||
// RefreshToken handles token refresh
|
||||
@@ -249,23 +243,57 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get refresh token from cookie
|
||||
cookie, err := r.Cookie("refresh_token")
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
http.Error(w, "Refresh token not found", http.StatusUnauthorized)
|
||||
http.Error(w, "Session not found", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, err := h.authService.RefreshAccessToken(r.Context(), cookie.Value)
|
||||
sessionData, err := h.sessionManager.GetSession(r.Context(), cookie.Value)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
|
||||
http.Error(w, "Invalid session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if err := h.sessionManager.RefreshSession(r.Context(), cookie.Value); err == nil {
|
||||
http.SetCookie(w, h.newSessionCookie(r, cookie.Value))
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"expires_in": 3600,
|
||||
"user": sessionData,
|
||||
"expires_in": int(h.sessionManager.TTL().Seconds()),
|
||||
})
|
||||
}
|
||||
|
||||
// Me returns the currently authenticated user profile.
|
||||
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
|
||||
sessionCookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
sessionData, err := h.sessionManager.GetSession(r.Context(), sessionCookie.Value)
|
||||
if err != nil {
|
||||
h.clearSessionCookie(w, r)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.authService.GetUserProfile(r.Context(), sessionData.UserID)
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.sessionManager.RefreshSession(r.Context(), sessionCookie.Value); err == nil {
|
||||
http.SetCookie(w, h.newSessionCookie(r, sessionCookie.Value))
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"user": user,
|
||||
"expires_in": int(h.sessionManager.TTL().Seconds()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -292,7 +320,7 @@ func buildBackendURL(r *http.Request, path string) string {
|
||||
return scheme + "://" + r.Host + path
|
||||
}
|
||||
|
||||
func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDTO) string {
|
||||
func buildFrontendLoginURL(status, message string) string {
|
||||
frontendURL := os.Getenv("FRONTEND_URL")
|
||||
if frontendURL == "" {
|
||||
frontendURL = "http://localhost:5173"
|
||||
@@ -310,14 +338,48 @@ func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDT
|
||||
if message != "" {
|
||||
query.Set("message", message)
|
||||
}
|
||||
if accessToken != "" {
|
||||
query.Set("access_token", accessToken)
|
||||
}
|
||||
if user != nil {
|
||||
payload, _ := json.Marshal(user)
|
||||
query.Set("user_json", string(payload))
|
||||
query.Set("user", base64.RawURLEncoding.EncodeToString(payload))
|
||||
}
|
||||
parsed.RawQuery = query.Encode()
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
func (h *AuthHandler) setSessionCookie(w http.ResponseWriter, r *http.Request, user *dto.UserDTO) error {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sessionID, err := h.sessionManager.CreateSession(r.Context(), &auth.SessionData{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
Username: user.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
http.SetCookie(w, h.newSessionCookie(r, sessionID))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *AuthHandler) newSessionCookie(r *http.Request, sessionID string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
MaxAge: int(h.sessionManager.TTL().Seconds()),
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AuthHandler) clearSessionCookie(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: isSecureRequest(r),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,13 +20,15 @@ const (
|
||||
|
||||
// AuthMiddleware verifies JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
jwtManager *auth.JWTManager
|
||||
jwtManager *auth.JWTManager
|
||||
sessionManager *auth.SessionManager
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new auth middleware
|
||||
func NewAuthMiddleware(jwtManager *auth.JWTManager) *AuthMiddleware {
|
||||
func NewAuthMiddleware(jwtManager *auth.JWTManager, sessionManager *auth.SessionManager) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
jwtManager: jwtManager,
|
||||
jwtManager: jwtManager,
|
||||
sessionManager: sessionManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,16 +43,23 @@ func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract token from Authorization header.
|
||||
// For GET /files/object, also accept ?token= so markdown images render in-browser.
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" && r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/files/object") {
|
||||
if tok := r.URL.Query().Get("token"); tok != "" {
|
||||
authHeader = "Bearer " + tok
|
||||
if sessionCookie, err := r.Cookie("session_id"); err == nil && sessionCookie.Value != "" {
|
||||
sessionData, sessionErr := m.sessionManager.GetSession(r.Context(), sessionCookie.Value)
|
||||
if sessionErr == nil {
|
||||
_ = m.sessionManager.RefreshSession(r.Context(), sessionCookie.Value)
|
||||
|
||||
ctx := context.WithValue(r.Context(), UserIDKey, sessionData.UserID)
|
||||
ctx = context.WithValue(ctx, EmailKey, sessionData.Email)
|
||||
r = r.WithContext(ctx)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to Authorization header for backwards compatibility.
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, "Missing authorization header", http.StatusUnauthorized)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ func CORSMiddleware(next http.Handler) http.Handler {
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Max-Age", "600")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
|
||||
Reference in New Issue
Block a user