fix: fixes to session storage
All checks were successful
Build and Push App Image / build-and-push (push) Successful in 1m27s

This commit is contained in:
domrichardson
2026-03-26 10:06:07 +00:00
parent 6774c401bf
commit 6e642da57a
17 changed files with 498 additions and 275 deletions

View File

@@ -114,22 +114,9 @@ func (s *AuthService) Register(ctx context.Context, req *dto.RegisterRequest) (*
}
}
// Generate tokens
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
if err != nil {
return nil, err
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
if err != nil {
return nil, err
}
return &dto.LoginResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
User: dto.NewUserDTO(user),
ExpiresIn: 3600, // 1 hour
User: dto.NewUserDTO(user),
ExpiresIn: 3600, // 1 hour
}, nil
}
@@ -165,27 +152,18 @@ func (s *AuthService) Login(ctx context.Context, req *dto.LoginRequest) (*dto.Lo
// Log error but don't fail the login
}
// Generate tokens
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
if err != nil {
return nil, err
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
if err != nil {
return nil, err
}
return &dto.LoginResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
User: dto.NewUserDTO(user),
ExpiresIn: 3600,
User: dto.NewUserDTO(user),
ExpiresIn: 3600,
}, nil
}
// RefreshAccessToken refreshes an access token
func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken string) (string, error) {
if s.jwtManager == nil {
return "", errors.New("jwt refresh is unavailable")
}
claims, err := s.jwtManager.VerifyRefreshToken(refreshToken)
if err != nil {
return "", err
@@ -199,6 +177,27 @@ func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken strin
return s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
}
// GetUserProfile returns profile DTO for the provided user ID.
func (s *AuthService) GetUserProfile(ctx context.Context, userID string) (*dto.UserDTO, error) {
objID, err := bson.ObjectIDFromHex(strings.TrimSpace(userID))
if err != nil {
return nil, errors.New("invalid user id")
}
user, err := s.userRepo.GetUserByID(ctx, objID)
if err != nil {
return nil, err
}
if s.permissionService != nil {
if err := s.permissionService.UpdateUserEffectivePermissions(ctx, user); err != nil {
return nil, err
}
}
return dto.NewUserDTO(user), nil
}
// RequestPasswordReset initiates password reset flow
func (s *AuthService) RequestPasswordReset(ctx context.Context, email string) error {
user, err := s.userRepo.GetUserByEmail(ctx, email)
@@ -444,17 +443,7 @@ func (s *AuthService) CompleteProviderLogin(ctx context.Context, providerID bson
return nil, err
}
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
if err != nil {
return nil, err
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
if err != nil {
return nil, err
}
return &dto.LoginResponse{AccessToken: accessToken, RefreshToken: refreshToken, User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil
return &dto.LoginResponse{User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil
}
type providerProfile struct {

View File

@@ -0,0 +1,114 @@
package auth
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/redis/go-redis/v9"
)
// SessionData stores authenticated identity data in Redis.
type SessionData struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
}
// SessionManager handles Redis-backed session lifecycle operations.
type SessionManager struct {
redis *redis.Client
ttl time.Duration
prefix string
}
func NewSessionManager(redisClient *redis.Client, ttl time.Duration) *SessionManager {
if ttl <= 0 {
ttl = 7 * 24 * time.Hour
}
return &SessionManager{
redis: redisClient,
ttl: ttl,
prefix: "session:",
}
}
func (m *SessionManager) TTL() time.Duration {
return m.ttl
}
func (m *SessionManager) CreateSession(ctx context.Context, data *SessionData) (string, error) {
if data == nil {
return "", errors.New("session data is required")
}
sessionID, err := GenerateRandomToken(32)
if err != nil {
return "", err
}
payload, err := json.Marshal(data)
if err != nil {
return "", err
}
if err := m.redis.Set(ctx, m.key(sessionID), payload, m.ttl).Err(); err != nil {
return "", err
}
return sessionID, nil
}
func (m *SessionManager) GetSession(ctx context.Context, sessionID string) (*SessionData, error) {
if sessionID == "" {
return nil, errors.New("session id is required")
}
payload, err := m.redis.Get(ctx, m.key(sessionID)).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, errors.New("session not found")
}
return nil, err
}
var data SessionData
if err := json.Unmarshal([]byte(payload), &data); err != nil {
return nil, err
}
return &data, nil
}
func (m *SessionManager) RefreshSession(ctx context.Context, sessionID string) error {
if sessionID == "" {
return errors.New("session id is required")
}
if err := m.redis.Expire(ctx, m.key(sessionID), m.ttl).Err(); err != nil {
if errors.Is(err, redis.Nil) {
return errors.New("session not found")
}
return err
}
return nil
}
func (m *SessionManager) DeleteSession(ctx context.Context, sessionID string) error {
if sessionID == "" {
return nil
}
if err := m.redis.Del(ctx, m.key(sessionID)).Err(); err != nil {
return err
}
return nil
}
func (m *SessionManager) key(sessionID string) string {
return m.prefix + sessionID
}

View File

@@ -1,7 +1,6 @@
package handlers
import (
"encoding/base64"
"encoding/json"
"net/http"
"net/url"
@@ -17,16 +16,20 @@ import (
// AuthHandler handles authentication endpoints
type AuthHandler struct {
authService *services.AuthService
authService *services.AuthService
sessionManager *auth.SessionManager
}
// NewAuthHandler creates a new auth handler
func NewAuthHandler(authService *services.AuthService) *AuthHandler {
func NewAuthHandler(authService *services.AuthService, sessionManager *auth.SessionManager) *AuthHandler {
return &AuthHandler{
authService: authService,
authService: authService,
sessionManager: sessionManager,
}
}
const sessionCookieName = "session_id"
// Register handles user registration
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -56,6 +59,11 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.setSessionCookie(w, r, response.User); err != nil {
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
@@ -79,16 +87,10 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
return
}
// Set secure HTTP-only cookie for refresh token
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: response.RefreshToken,
Path: "/",
MaxAge: 7 * 24 * 60 * 60, // 7 days
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
})
if err := h.setSessionCookie(w, r, response.User); err != nil {
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
@@ -96,15 +98,12 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
// Logout handles user logout
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// Clear refresh token cookie
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
Secure: isSecureRequest(r),
})
sessionCookie, err := r.Cookie(sessionCookieName)
if err == nil {
_ = h.sessionManager.DeleteSession(r.Context(), sessionCookie.Value)
}
h.clearSessionCookie(w, r)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
@@ -215,7 +214,7 @@ func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Reque
response, err := h.authService.CompleteProviderLogin(r.Context(), providerID, r.URL.Query().Get("code"), buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback"))
if err != nil {
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error(), "", nil), http.StatusFound)
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error()), http.StatusFound)
return
}
@@ -229,17 +228,12 @@ func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Reque
SameSite: http.SameSiteLaxMode,
})
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: response.RefreshToken,
Path: "/",
MaxAge: 7 * 24 * 60 * 60,
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
})
if err := h.setSessionCookie(w, r, response.User); err != nil {
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", "Failed to create session"), http.StatusFound)
return
}
http.Redirect(w, r, buildFrontendLoginURL("oauth_success", "", response.AccessToken, response.User), http.StatusFound)
http.Redirect(w, r, buildFrontendLoginURL("oauth_success", ""), http.StatusFound)
}
// RefreshToken handles token refresh
@@ -249,23 +243,57 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
return
}
// Get refresh token from cookie
cookie, err := r.Cookie("refresh_token")
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
http.Error(w, "Refresh token not found", http.StatusUnauthorized)
http.Error(w, "Session not found", http.StatusUnauthorized)
return
}
accessToken, err := h.authService.RefreshAccessToken(r.Context(), cookie.Value)
sessionData, err := h.sessionManager.GetSession(r.Context(), cookie.Value)
if err != nil {
http.Error(w, "Invalid refresh token", http.StatusUnauthorized)
http.Error(w, "Invalid session", http.StatusUnauthorized)
return
}
if err := h.sessionManager.RefreshSession(r.Context(), cookie.Value); err == nil {
http.SetCookie(w, h.newSessionCookie(r, cookie.Value))
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"access_token": accessToken,
"expires_in": 3600,
"user": sessionData,
"expires_in": int(h.sessionManager.TTL().Seconds()),
})
}
// Me returns the currently authenticated user profile.
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
sessionCookie, err := r.Cookie(sessionCookieName)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
sessionData, err := h.sessionManager.GetSession(r.Context(), sessionCookie.Value)
if err != nil {
h.clearSessionCookie(w, r)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
user, err := h.authService.GetUserProfile(r.Context(), sessionData.UserID)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if err := h.sessionManager.RefreshSession(r.Context(), sessionCookie.Value); err == nil {
http.SetCookie(w, h.newSessionCookie(r, sessionCookie.Value))
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"user": user,
"expires_in": int(h.sessionManager.TTL().Seconds()),
})
}
@@ -292,7 +320,7 @@ func buildBackendURL(r *http.Request, path string) string {
return scheme + "://" + r.Host + path
}
func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDTO) string {
func buildFrontendLoginURL(status, message string) string {
frontendURL := os.Getenv("FRONTEND_URL")
if frontendURL == "" {
frontendURL = "http://localhost:5173"
@@ -310,14 +338,48 @@ func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDT
if message != "" {
query.Set("message", message)
}
if accessToken != "" {
query.Set("access_token", accessToken)
}
if user != nil {
payload, _ := json.Marshal(user)
query.Set("user_json", string(payload))
query.Set("user", base64.RawURLEncoding.EncodeToString(payload))
}
parsed.RawQuery = query.Encode()
return parsed.String()
}
func (h *AuthHandler) setSessionCookie(w http.ResponseWriter, r *http.Request, user *dto.UserDTO) error {
if user == nil {
return nil
}
sessionID, err := h.sessionManager.CreateSession(r.Context(), &auth.SessionData{
UserID: user.ID,
Email: user.Email,
Username: user.Username,
})
if err != nil {
return err
}
http.SetCookie(w, h.newSessionCookie(r, sessionID))
return nil
}
func (h *AuthHandler) newSessionCookie(r *http.Request, sessionID string) *http.Cookie {
return &http.Cookie{
Name: sessionCookieName,
Value: sessionID,
Path: "/",
MaxAge: int(h.sessionManager.TTL().Seconds()),
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
}
}
func (h *AuthHandler) clearSessionCookie(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
})
}

View File

@@ -20,13 +20,15 @@ const (
// AuthMiddleware verifies JWT tokens
type AuthMiddleware struct {
jwtManager *auth.JWTManager
jwtManager *auth.JWTManager
sessionManager *auth.SessionManager
}
// NewAuthMiddleware creates a new auth middleware
func NewAuthMiddleware(jwtManager *auth.JWTManager) *AuthMiddleware {
func NewAuthMiddleware(jwtManager *auth.JWTManager, sessionManager *auth.SessionManager) *AuthMiddleware {
return &AuthMiddleware{
jwtManager: jwtManager,
jwtManager: jwtManager,
sessionManager: sessionManager,
}
}
@@ -41,16 +43,23 @@ func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
return
}
// Extract token from Authorization header.
// For GET /files/object, also accept ?token= so markdown images render in-browser.
authHeader := r.Header.Get("Authorization")
if authHeader == "" && r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/files/object") {
if tok := r.URL.Query().Get("token"); tok != "" {
authHeader = "Bearer " + tok
if sessionCookie, err := r.Cookie("session_id"); err == nil && sessionCookie.Value != "" {
sessionData, sessionErr := m.sessionManager.GetSession(r.Context(), sessionCookie.Value)
if sessionErr == nil {
_ = m.sessionManager.RefreshSession(r.Context(), sessionCookie.Value)
ctx := context.WithValue(r.Context(), UserIDKey, sessionData.UserID)
ctx = context.WithValue(ctx, EmailKey, sessionData.Email)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
return
}
}
// Fall back to Authorization header for backwards compatibility.
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, "Missing authorization header", http.StatusUnauthorized)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}

View File

@@ -79,6 +79,7 @@ func CORSMiddleware(next http.Handler) http.Handler {
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Max-Age", "600")
if r.Method == http.MethodOptions {