1 Commits

Author SHA1 Message Date
domrichardson
6e642da57a fix: fixes to session storage
All checks were successful
Build and Push App Image / build-and-push (push) Successful in 1m27s
2026-03-26 10:06:07 +00:00
17 changed files with 498 additions and 275 deletions

View File

@@ -26,3 +26,9 @@ CORS_ALLOWED_ORIGINS=http://localhost:5173,http://localhost:3000
# Rate Limiting # Rate Limiting
RATE_LIMIT_REQUESTS=50 RATE_LIMIT_REQUESTS=50
RATE_LIMIT_WINDOW=1s RATE_LIMIT_WINDOW=1s
# Redis Sessions
REDIS_ADDR=localhost:6379
REDIS_PASSWORD=
REDIS_DB=0
SESSION_TTL_HOURS=168

View File

@@ -6,6 +6,7 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
@@ -19,6 +20,7 @@ import (
"github.com/noteapp/backend/internal/infrastructure/security" "github.com/noteapp/backend/internal/infrastructure/security"
"github.com/noteapp/backend/internal/interfaces/handlers" "github.com/noteapp/backend/internal/interfaces/handlers"
"github.com/noteapp/backend/internal/interfaces/middleware" "github.com/noteapp/backend/internal/interfaces/middleware"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/bson"
) )
@@ -47,6 +49,30 @@ func main() {
port = "8080" port = "8080"
} }
redisAddr := os.Getenv("REDIS_ADDR")
if redisAddr == "" {
redisAddr = "localhost:6379"
}
redisPassword := os.Getenv("REDIS_PASSWORD")
redisDB := 0
if redisDBText := os.Getenv("REDIS_DB"); redisDBText != "" {
parsedDB, err := strconv.Atoi(redisDBText)
if err != nil {
log.Fatalf("invalid REDIS_DB value: %v", err)
}
redisDB = parsedDB
}
sessionTTL := 7 * 24 * time.Hour
if sessionTTLText := os.Getenv("SESSION_TTL_HOURS"); sessionTTLText != "" {
hours, err := strconv.Atoi(sessionTTLText)
if err != nil || hours <= 0 {
log.Fatalf("invalid SESSION_TTL_HOURS value: %q", sessionTTLText)
}
sessionTTL = time.Duration(hours) * time.Hour
}
// Connect to database // Connect to database
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
@@ -57,6 +83,19 @@ func main() {
} }
defer db.Close(context.Background()) defer db.Close(context.Background())
redisClient := redis.NewClient(&redis.Options{
Addr: redisAddr,
Password: redisPassword,
DB: redisDB,
})
if err := redisClient.Ping(context.Background()).Err(); err != nil {
log.Fatalf("failed to connect to redis: %v", err)
}
defer func() {
_ = redisClient.Close()
}()
// Initialize security components // Initialize security components
passwordHasher := security.NewPasswordHasher() passwordHasher := security.NewPasswordHasher()
encryptor, err := security.NewEncryptor(encryptionKey) encryptor, err := security.NewEncryptor(encryptionKey)
@@ -66,6 +105,7 @@ func main() {
// Initialize JWT manager // Initialize JWT manager
jwtManager := auth.NewJWTManager(jwtSecret, "noteapp", 1*time.Hour) jwtManager := auth.NewJWTManager(jwtSecret, "noteapp", 1*time.Hour)
sessionManager := auth.NewSessionManager(redisClient, sessionTTL)
// Initialize services // Initialize services
permissionService := services.NewPermissionService( permissionService := services.NewPermissionService(
@@ -143,7 +183,7 @@ func main() {
} }
// Initialize handlers // Initialize handlers
authHandler := handlers.NewAuthHandler(authService) authHandler := handlers.NewAuthHandler(authService, sessionManager)
spaceHandler := handlers.NewSpaceHandler(spaceService) spaceHandler := handlers.NewSpaceHandler(spaceService)
noteHandler := handlers.NewNoteHandler(noteService) noteHandler := handlers.NewNoteHandler(noteService)
categoryHandler := handlers.NewCategoryHandler(categoryService) categoryHandler := handlers.NewCategoryHandler(categoryService)
@@ -160,7 +200,7 @@ func main() {
}) })
// Middleware // Middleware
authMiddleware := middleware.NewAuthMiddleware(jwtManager) authMiddleware := middleware.NewAuthMiddleware(jwtManager, sessionManager)
router.Use(middleware.LoggingMiddleware) router.Use(middleware.LoggingMiddleware)
router.Use(middleware.CORSMiddleware) router.Use(middleware.CORSMiddleware)
router.Use(middleware.SecurityHeaders) router.Use(middleware.SecurityHeaders)
@@ -187,6 +227,7 @@ func main() {
// Protected endpoints // Protected endpoints
api := router.PathPrefix("/api/v1").Subrouter() api := router.PathPrefix("/api/v1").Subrouter()
api.Use(authMiddleware.Middleware) api.Use(authMiddleware.Middleware)
api.HandleFunc("/auth/me", authHandler.Me).Methods("GET")
// Space endpoints // Space endpoints
api.HandleFunc("/spaces", spaceHandler.GetUserSpaces).Methods("GET") api.HandleFunc("/spaces", spaceHandler.GetUserSpaces).Methods("GET")

View File

@@ -3,18 +3,20 @@ module github.com/noteapp/backend
go 1.25.0 go 1.25.0
require ( require (
github.com/aws/aws-sdk-go-v2 v1.41.4
github.com/aws/aws-sdk-go-v2/credentials v1.19.12
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.2
github.com/golang-jwt/jwt/v5 v5.2.0 github.com/golang-jwt/jwt/v5 v5.2.0
github.com/gorilla/mux v1.8.1 github.com/gorilla/mux v1.8.1
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/redis/go-redis/v9 v9.18.0
go.mongodb.org/mongo-driver/v2 v2.5.0 go.mongodb.org/mongo-driver/v2 v2.5.0
golang.org/x/crypto v0.49.0 golang.org/x/crypto v0.49.0
golang.org/x/oauth2 v0.30.0 golang.org/x/oauth2 v0.30.0
) )
require ( require (
github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 // indirect
@@ -22,13 +24,15 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 // indirect
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.2 // indirect
github.com/aws/smithy-go v1.24.2 // indirect github.com/aws/smithy-go v1.24.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/klauspost/compress v1.17.6 // indirect github.com/klauspost/compress v1.17.6 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.2.0 // indirect github.com/xdg-go/scram v1.2.0 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/sync v0.20.0 // indirect golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect golang.org/x/text v0.35.0 // indirect

View File

@@ -22,8 +22,16 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.97.2 h1:MRNiP6nqa20aEl8fQ6PJpEq11b2d4
github.com/aws/aws-sdk-go-v2/service/s3 v1.97.2/go.mod h1:FrNA56srbsr3WShiaelyWYEo70x80mXnVZ17ZZfbeqg= github.com/aws/aws-sdk-go-v2/service/s3 v1.97.2/go.mod h1:FrNA56srbsr3WShiaelyWYEo70x80mXnVZ17ZZfbeqg=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
@@ -34,6 +42,14 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
@@ -43,8 +59,12 @@ github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gi
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=

View File

@@ -114,20 +114,7 @@ func (s *AuthService) Register(ctx context.Context, req *dto.RegisterRequest) (*
} }
} }
// Generate tokens
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
if err != nil {
return nil, err
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
if err != nil {
return nil, err
}
return &dto.LoginResponse{ return &dto.LoginResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
User: dto.NewUserDTO(user), User: dto.NewUserDTO(user),
ExpiresIn: 3600, // 1 hour ExpiresIn: 3600, // 1 hour
}, nil }, nil
@@ -165,20 +152,7 @@ func (s *AuthService) Login(ctx context.Context, req *dto.LoginRequest) (*dto.Lo
// Log error but don't fail the login // Log error but don't fail the login
} }
// Generate tokens
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
if err != nil {
return nil, err
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
if err != nil {
return nil, err
}
return &dto.LoginResponse{ return &dto.LoginResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
User: dto.NewUserDTO(user), User: dto.NewUserDTO(user),
ExpiresIn: 3600, ExpiresIn: 3600,
}, nil }, nil
@@ -186,6 +160,10 @@ func (s *AuthService) Login(ctx context.Context, req *dto.LoginRequest) (*dto.Lo
// RefreshAccessToken refreshes an access token // RefreshAccessToken refreshes an access token
func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken string) (string, error) { func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken string) (string, error) {
if s.jwtManager == nil {
return "", errors.New("jwt refresh is unavailable")
}
claims, err := s.jwtManager.VerifyRefreshToken(refreshToken) claims, err := s.jwtManager.VerifyRefreshToken(refreshToken)
if err != nil { if err != nil {
return "", err return "", err
@@ -199,6 +177,27 @@ func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken strin
return s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) return s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
} }
// GetUserProfile returns profile DTO for the provided user ID.
func (s *AuthService) GetUserProfile(ctx context.Context, userID string) (*dto.UserDTO, error) {
objID, err := bson.ObjectIDFromHex(strings.TrimSpace(userID))
if err != nil {
return nil, errors.New("invalid user id")
}
user, err := s.userRepo.GetUserByID(ctx, objID)
if err != nil {
return nil, err
}
if s.permissionService != nil {
if err := s.permissionService.UpdateUserEffectivePermissions(ctx, user); err != nil {
return nil, err
}
}
return dto.NewUserDTO(user), nil
}
// RequestPasswordReset initiates password reset flow // RequestPasswordReset initiates password reset flow
func (s *AuthService) RequestPasswordReset(ctx context.Context, email string) error { func (s *AuthService) RequestPasswordReset(ctx context.Context, email string) error {
user, err := s.userRepo.GetUserByEmail(ctx, email) user, err := s.userRepo.GetUserByEmail(ctx, email)
@@ -444,17 +443,7 @@ func (s *AuthService) CompleteProviderLogin(ctx context.Context, providerID bson
return nil, err return nil, err
} }
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) return &dto.LoginResponse{User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil
if err != nil {
return nil, err
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
if err != nil {
return nil, err
}
return &dto.LoginResponse{AccessToken: accessToken, RefreshToken: refreshToken, User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil
} }
type providerProfile struct { type providerProfile struct {

View File

@@ -0,0 +1,114 @@
package auth
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/redis/go-redis/v9"
)
// SessionData stores authenticated identity data in Redis.
type SessionData struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
}
// SessionManager handles Redis-backed session lifecycle operations.
type SessionManager struct {
redis *redis.Client
ttl time.Duration
prefix string
}
func NewSessionManager(redisClient *redis.Client, ttl time.Duration) *SessionManager {
if ttl <= 0 {
ttl = 7 * 24 * time.Hour
}
return &SessionManager{
redis: redisClient,
ttl: ttl,
prefix: "session:",
}
}
func (m *SessionManager) TTL() time.Duration {
return m.ttl
}
func (m *SessionManager) CreateSession(ctx context.Context, data *SessionData) (string, error) {
if data == nil {
return "", errors.New("session data is required")
}
sessionID, err := GenerateRandomToken(32)
if err != nil {
return "", err
}
payload, err := json.Marshal(data)
if err != nil {
return "", err
}
if err := m.redis.Set(ctx, m.key(sessionID), payload, m.ttl).Err(); err != nil {
return "", err
}
return sessionID, nil
}
func (m *SessionManager) GetSession(ctx context.Context, sessionID string) (*SessionData, error) {
if sessionID == "" {
return nil, errors.New("session id is required")
}
payload, err := m.redis.Get(ctx, m.key(sessionID)).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, errors.New("session not found")
}
return nil, err
}
var data SessionData
if err := json.Unmarshal([]byte(payload), &data); err != nil {
return nil, err
}
return &data, nil
}
func (m *SessionManager) RefreshSession(ctx context.Context, sessionID string) error {
if sessionID == "" {
return errors.New("session id is required")
}
if err := m.redis.Expire(ctx, m.key(sessionID), m.ttl).Err(); err != nil {
if errors.Is(err, redis.Nil) {
return errors.New("session not found")
}
return err
}
return nil
}
func (m *SessionManager) DeleteSession(ctx context.Context, sessionID string) error {
if sessionID == "" {
return nil
}
if err := m.redis.Del(ctx, m.key(sessionID)).Err(); err != nil {
return err
}
return nil
}
func (m *SessionManager) key(sessionID string) string {
return m.prefix + sessionID
}

View File

@@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/url" "net/url"
@@ -18,15 +17,19 @@ import (
// AuthHandler handles authentication endpoints // AuthHandler handles authentication endpoints
type AuthHandler struct { type AuthHandler struct {
authService *services.AuthService authService *services.AuthService
sessionManager *auth.SessionManager
} }
// NewAuthHandler creates a new auth handler // NewAuthHandler creates a new auth handler
func NewAuthHandler(authService *services.AuthService) *AuthHandler { func NewAuthHandler(authService *services.AuthService, sessionManager *auth.SessionManager) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
authService: authService, authService: authService,
sessionManager: sessionManager,
} }
} }
const sessionCookieName = "session_id"
// Register handles user registration // Register handles user registration
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@@ -56,6 +59,11 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := h.setSessionCookie(w, r, response.User); err != nil {
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response) json.NewEncoder(w).Encode(response)
} }
@@ -79,16 +87,10 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
return return
} }
// Set secure HTTP-only cookie for refresh token if err := h.setSessionCookie(w, r, response.User); err != nil {
http.SetCookie(w, &http.Cookie{ http.Error(w, "Failed to create session", http.StatusInternalServerError)
Name: "refresh_token", return
Value: response.RefreshToken, }
Path: "/",
MaxAge: 7 * 24 * 60 * 60, // 7 days
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
})
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response) json.NewEncoder(w).Encode(response)
@@ -96,15 +98,12 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
// Logout handles user logout // Logout handles user logout
func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// Clear refresh token cookie sessionCookie, err := r.Cookie(sessionCookieName)
http.SetCookie(w, &http.Cookie{ if err == nil {
Name: "refresh_token", _ = h.sessionManager.DeleteSession(r.Context(), sessionCookie.Value)
Value: "", }
Path: "/",
MaxAge: -1, h.clearSessionCookie(w, r)
HttpOnly: true,
Secure: isSecureRequest(r),
})
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"}) json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"})
@@ -215,7 +214,7 @@ func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Reque
response, err := h.authService.CompleteProviderLogin(r.Context(), providerID, r.URL.Query().Get("code"), buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback")) response, err := h.authService.CompleteProviderLogin(r.Context(), providerID, r.URL.Query().Get("code"), buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback"))
if err != nil { if err != nil {
http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error(), "", nil), http.StatusFound) http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error()), http.StatusFound)
return return
} }
@@ -229,17 +228,12 @@ func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Reque
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}) })
http.SetCookie(w, &http.Cookie{ if err := h.setSessionCookie(w, r, response.User); err != nil {
Name: "refresh_token", http.Redirect(w, r, buildFrontendLoginURL("oauth_error", "Failed to create session"), http.StatusFound)
Value: response.RefreshToken, return
Path: "/", }
MaxAge: 7 * 24 * 60 * 60,
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
})
http.Redirect(w, r, buildFrontendLoginURL("oauth_success", "", response.AccessToken, response.User), http.StatusFound) http.Redirect(w, r, buildFrontendLoginURL("oauth_success", ""), http.StatusFound)
} }
// RefreshToken handles token refresh // RefreshToken handles token refresh
@@ -249,23 +243,57 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
return return
} }
// Get refresh token from cookie cookie, err := r.Cookie(sessionCookieName)
cookie, err := r.Cookie("refresh_token")
if err != nil { if err != nil {
http.Error(w, "Refresh token not found", http.StatusUnauthorized) http.Error(w, "Session not found", http.StatusUnauthorized)
return return
} }
accessToken, err := h.authService.RefreshAccessToken(r.Context(), cookie.Value) sessionData, err := h.sessionManager.GetSession(r.Context(), cookie.Value)
if err != nil { if err != nil {
http.Error(w, "Invalid refresh token", http.StatusUnauthorized) http.Error(w, "Invalid session", http.StatusUnauthorized)
return return
} }
if err := h.sessionManager.RefreshSession(r.Context(), cookie.Value); err == nil {
http.SetCookie(w, h.newSessionCookie(r, cookie.Value))
}
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{ json.NewEncoder(w).Encode(map[string]interface{}{
"access_token": accessToken, "user": sessionData,
"expires_in": 3600, "expires_in": int(h.sessionManager.TTL().Seconds()),
})
}
// Me returns the currently authenticated user profile.
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
sessionCookie, err := r.Cookie(sessionCookieName)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
sessionData, err := h.sessionManager.GetSession(r.Context(), sessionCookie.Value)
if err != nil {
h.clearSessionCookie(w, r)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
user, err := h.authService.GetUserProfile(r.Context(), sessionData.UserID)
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if err := h.sessionManager.RefreshSession(r.Context(), sessionCookie.Value); err == nil {
http.SetCookie(w, h.newSessionCookie(r, sessionCookie.Value))
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"user": user,
"expires_in": int(h.sessionManager.TTL().Seconds()),
}) })
} }
@@ -292,7 +320,7 @@ func buildBackendURL(r *http.Request, path string) string {
return scheme + "://" + r.Host + path return scheme + "://" + r.Host + path
} }
func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDTO) string { func buildFrontendLoginURL(status, message string) string {
frontendURL := os.Getenv("FRONTEND_URL") frontendURL := os.Getenv("FRONTEND_URL")
if frontendURL == "" { if frontendURL == "" {
frontendURL = "http://localhost:5173" frontendURL = "http://localhost:5173"
@@ -310,14 +338,48 @@ func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDT
if message != "" { if message != "" {
query.Set("message", message) query.Set("message", message)
} }
if accessToken != "" {
query.Set("access_token", accessToken)
}
if user != nil {
payload, _ := json.Marshal(user)
query.Set("user_json", string(payload))
query.Set("user", base64.RawURLEncoding.EncodeToString(payload))
}
parsed.RawQuery = query.Encode() parsed.RawQuery = query.Encode()
return parsed.String() return parsed.String()
} }
func (h *AuthHandler) setSessionCookie(w http.ResponseWriter, r *http.Request, user *dto.UserDTO) error {
if user == nil {
return nil
}
sessionID, err := h.sessionManager.CreateSession(r.Context(), &auth.SessionData{
UserID: user.ID,
Email: user.Email,
Username: user.Username,
})
if err != nil {
return err
}
http.SetCookie(w, h.newSessionCookie(r, sessionID))
return nil
}
func (h *AuthHandler) newSessionCookie(r *http.Request, sessionID string) *http.Cookie {
return &http.Cookie{
Name: sessionCookieName,
Value: sessionID,
Path: "/",
MaxAge: int(h.sessionManager.TTL().Seconds()),
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
}
}
func (h *AuthHandler) clearSessionCookie(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
Secure: isSecureRequest(r),
SameSite: http.SameSiteLaxMode,
})
}

View File

@@ -21,12 +21,14 @@ const (
// AuthMiddleware verifies JWT tokens // AuthMiddleware verifies JWT tokens
type AuthMiddleware struct { type AuthMiddleware struct {
jwtManager *auth.JWTManager jwtManager *auth.JWTManager
sessionManager *auth.SessionManager
} }
// NewAuthMiddleware creates a new auth middleware // NewAuthMiddleware creates a new auth middleware
func NewAuthMiddleware(jwtManager *auth.JWTManager) *AuthMiddleware { func NewAuthMiddleware(jwtManager *auth.JWTManager, sessionManager *auth.SessionManager) *AuthMiddleware {
return &AuthMiddleware{ return &AuthMiddleware{
jwtManager: jwtManager, jwtManager: jwtManager,
sessionManager: sessionManager,
} }
} }
@@ -41,16 +43,23 @@ func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler {
return return
} }
// Extract token from Authorization header. if sessionCookie, err := r.Cookie("session_id"); err == nil && sessionCookie.Value != "" {
// For GET /files/object, also accept ?token= so markdown images render in-browser. sessionData, sessionErr := m.sessionManager.GetSession(r.Context(), sessionCookie.Value)
if sessionErr == nil {
_ = m.sessionManager.RefreshSession(r.Context(), sessionCookie.Value)
ctx := context.WithValue(r.Context(), UserIDKey, sessionData.UserID)
ctx = context.WithValue(ctx, EmailKey, sessionData.Email)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
return
}
}
// Fall back to Authorization header for backwards compatibility.
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
if authHeader == "" && r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/files/object") {
if tok := r.URL.Query().Get("token"); tok != "" {
authHeader = "Bearer " + tok
}
}
if authHeader == "" { if authHeader == "" {
http.Error(w, "Missing authorization header", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }

View File

@@ -79,6 +79,7 @@ func CORSMiddleware(next http.Handler) http.Handler {
} }
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Max-Age", "600") w.Header().Set("Access-Control-Max-Age", "600")
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {

View File

@@ -44,20 +44,6 @@ http {
listen 80; listen 80;
server_name localhost; server_name localhost;
# API routes
location /api/ {
limit_req zone=api_limit burst=20 nodelay;
proxy_pass http://notely;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_connect_timeout 60s;
proxy_send_timeout 60s;
proxy_read_timeout 60s;
}
# Health check # Health check
location /health { location /health {
proxy_pass http://notely; proxy_pass http://notely;

View File

@@ -1,6 +1,17 @@
version: "3.8"
services: services:
redis:
image: redis:8-alpine
container_name: notely-redis
ports:
- "6379:6379"
networks:
- notely-network
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
mongodb: mongodb:
image: mongo:8.0 image: mongo:8.0
container_name: notely-mongodb container_name: notely-mongodb
@@ -39,9 +50,15 @@ services:
DEFAULT_ADMIN_EMAIL: ${DEFAULT_ADMIN_EMAIL} DEFAULT_ADMIN_EMAIL: ${DEFAULT_ADMIN_EMAIL}
DEFAULT_ADMIN_USERNAME: ${DEFAULT_ADMIN_USERNAME} DEFAULT_ADMIN_USERNAME: ${DEFAULT_ADMIN_USERNAME}
DEFAULT_ADMIN_PASSWORD: ${DEFAULT_ADMIN_PASSWORD} DEFAULT_ADMIN_PASSWORD: ${DEFAULT_ADMIN_PASSWORD}
REDIS_ADDR: ${REDIS_ADDR}
REDIS_PASSWORD: ${REDIS_PASSWORD}
REDIS_DB: ${REDIS_DB}
SESSION_TTL_HOURS: ${SESSION_TTL_HOURS}
depends_on: depends_on:
mongodb: mongodb:
condition: service_healthy condition: service_healthy
redis:
condition: service_healthy
networks: networks:
- notely-network - notely-network

View File

@@ -2,7 +2,6 @@
<div class="note-editor"> <div class="note-editor">
<div class="editor-toolbar mb-3"> <div class="editor-toolbar mb-3">
<button class="btn btn-sm btn-primary" @click="saveNote">Save</button> <button class="btn btn-sm btn-primary" @click="saveNote">Save</button>
<button v-if="canDelete" class="btn btn-sm btn-danger ms-2" @click="confirmDelete">Delete</button>
<button class="btn btn-sm btn-outline-secondary ms-2" @click="emit('cancel')">Cancel</button> <button class="btn btn-sm btn-outline-secondary ms-2" @click="emit('cancel')">Cancel</button>
<button <button
v-if="fileExplorerEnabled" v-if="fileExplorerEnabled"
@@ -82,6 +81,15 @@
</select> </select>
<input v-if="passwordAction === 'set'" v-model="notePassword" type="password" class="form-control mt-2" minlength="4" maxlength="128" placeholder="Enter a note password" /> <input v-if="passwordAction === 'set'" v-model="notePassword" type="password" class="form-control mt-2" minlength="4" maxlength="128" placeholder="Enter a note password" />
</div> </div>
<section v-if="canDelete && editingNote.id" class="danger-zone mt-4" aria-labelledby="danger-zone-title">
<h3 id="danger-zone-title" class="danger-zone-title mb-2">Danger Zone</h3>
<p class="danger-zone-copy mb-3">Deleting this note is permanent and cannot be undone.</p>
<button class="btn btn-danger" type="button" @click="confirmDelete">
<i class="mdi mdi-delete-outline me-1" aria-hidden="true"></i>
Delete Note
</button>
</section>
</div> </div>
</div> </div>
</template> </template>
@@ -91,7 +99,6 @@ import { ref, computed, watch, onBeforeUnmount, onMounted, nextTick } from "vue"
import { marked } from "marked"; import { marked } from "marked";
import DOMPurify from "dompurify"; import DOMPurify from "dompurify";
import { useSettingsStore } from "../stores/settingsStore"; import { useSettingsStore } from "../stores/settingsStore";
import { useAuthStore } from "../stores/authStore";
import { preprocessMarkdown } from "../utils/markdown.js"; import { preprocessMarkdown } from "../utils/markdown.js";
import FileExplorer from "./FileExplorer.vue"; import FileExplorer from "./FileExplorer.vue";
@@ -116,7 +123,6 @@ const props = defineProps({
const emit = defineEmits(["save", "delete", "cancel"]); const emit = defineEmits(["save", "delete", "cancel"]);
const settingsStore = useSettingsStore(); const settingsStore = useSettingsStore();
const authStore = useAuthStore();
const publicSharingEnabled = ref(true); const publicSharingEnabled = ref(true);
const fileExplorerEnabled = computed(() => settingsStore.fileExplorerEnabled); const fileExplorerEnabled = computed(() => settingsStore.fileExplorerEnabled);
@@ -133,17 +139,7 @@ const saveStateTimeout = ref(null);
const renderedMarkdown = computed(() => { const renderedMarkdown = computed(() => {
const html = marked.parse(preprocessMarkdown(editingNote.value.content || "")); const html = marked.parse(preprocessMarkdown(editingNote.value.content || ""));
let clean = DOMPurify.sanitize(html); return DOMPurify.sanitize(html);
// Inject access token into space file API URLs so images render without a separate JS fetch
const token = authStore.accessToken;
if (token && props.spaceId) {
clean = clean.replace(/((?:src|href)=["'])([^"']*\/api\/v1\/spaces\/[^"']*\/files\/object[^"']*)(["'])/g, (_, attr, url, quote) => {
if (url.includes("token=")) return attr + url + quote;
const sep = url.includes("?") ? "&" : "?";
return `${attr}${url}${sep}token=${encodeURIComponent(token)}${quote}`;
});
}
return clean;
}); });
const saveStatusLabel = computed(() => { const saveStatusLabel = computed(() => {
@@ -294,7 +290,7 @@ onMounted(async () => {
.editor-textarea { .editor-textarea {
font-family: "Courier New", monospace; font-family: "Courier New", monospace;
min-height: 400px; min-height: 600px;
resize: vertical; resize: vertical;
} }
@@ -333,4 +329,22 @@ onMounted(async () => {
overflow-y: auto; overflow-y: auto;
max-height: 600px; max-height: 600px;
} }
.danger-zone {
padding: 1rem;
border: 1px solid #f3b5b5;
border-radius: 0.75rem;
background: #fff5f5;
}
.danger-zone-title {
color: #9f1c1c;
font-size: 1rem;
font-weight: 700;
}
.danger-zone-copy {
color: #7a2727;
font-size: 0.9rem;
}
</style> </style>

View File

@@ -32,7 +32,6 @@
import { computed } from "vue"; import { computed } from "vue";
import { marked } from "marked"; import { marked } from "marked";
import DOMPurify from "dompurify"; import DOMPurify from "dompurify";
import { useAuthStore } from "../stores/authStore";
import { preprocessMarkdown } from "../utils/markdown.js"; import { preprocessMarkdown } from "../utils/markdown.js";
const props = defineProps({ const props = defineProps({
@@ -50,20 +49,9 @@ const props = defineProps({
}, },
}); });
const authStore = useAuthStore();
const renderedMarkdown = computed(() => { const renderedMarkdown = computed(() => {
const html = marked.parse(preprocessMarkdown(props.note.content || "")); const html = marked.parse(preprocessMarkdown(props.note.content || ""));
let clean = DOMPurify.sanitize(html); return DOMPurify.sanitize(html);
const token = authStore.accessToken;
if (token && props.spaceId) {
clean = clean.replace(/((?:src|href)=["'])([^"']*\/api\/v1\/spaces\/[^"']*\/files\/object[^"']*)(["'])/g, (_, attr, url, quote) => {
if (url.includes("token=")) return attr + url + quote;
const sep = url.includes("?") ? "&" : "?";
return `${attr}${url}${sep}token=${encodeURIComponent(token)}${quote}`;
});
}
return clean;
}); });
const categoryLabel = computed(() => { const categoryLabel = computed(() => {

View File

@@ -88,73 +88,33 @@ const startProviderLogin = (providerId) => {
window.location.href = `${apiClient.defaults.baseURL}/api/v1/auth/providers/${providerId}/start`; window.location.href = `${apiClient.defaults.baseURL}/api/v1/auth/providers/${providerId}/start`;
}; };
const decodeBase64Url = (value) => {
const normalized = value.replace(/-/g, "+").replace(/_/g, "/");
const padding = normalized.length % 4;
const padded = padding === 0 ? normalized : `${normalized}${"=".repeat(4 - padding)}`;
return atob(padded);
};
const decodeBase64UrlUTF8 = (value) => {
const binary = decodeBase64Url(value);
const bytes = Uint8Array.from(binary, (ch) => ch.charCodeAt(0));
return new TextDecoder().decode(bytes);
};
const readUserFromQuery = (params) => {
const plainUserJSON = params.get("user_json");
if (plainUserJSON) {
return JSON.parse(plainUserJSON);
}
const encodedUser = params.get("user");
if (encodedUser) {
return JSON.parse(decodeBase64UrlUTF8(encodedUser));
}
return null;
};
const completeOAuthRedirect = async () => { const completeOAuthRedirect = async () => {
const params = new URLSearchParams(window.location.search); const params = new URLSearchParams(window.location.search);
const status = params.get("status"); const status = params.get("status");
const accessToken = params.get("access_token") || params.get("accessToken") || params.get("token");
if (status === "oauth_error") { if (status === "oauth_error") {
error.value = params.get("message") || "Provider sign-in failed."; error.value = params.get("message") || "Provider sign-in failed.";
return true; return true;
} }
// Accept callback payloads even when `status` is missing. if (status !== "oauth_success") {
if (status !== "oauth_success" && !accessToken) {
if (status === "oauth_error") {
error.value = params.get("message") || "Provider sign-in failed.";
}
return false; return false;
} }
if (!accessToken) {
error.value = "Provider sign-in returned an incomplete session.";
return true;
}
try { try {
const user = readUserFromQuery(params); await authStore.ensureInitialized();
if (!user) {
error.value = "Provider sign-in returned an incomplete session.";
return true;
}
authStore.setSession({ access_token: accessToken, user });
await router.replace("/");
} catch { } catch {
error.value = "Unable to restore the provider session."; error.value = "Unable to restore provider session.";
return true;
} }
if (authStore.isAuthenticated) { if (authStore.isAuthenticated) {
window.location.replace("/"); await router.replace("/");
return true;
} }
error.value = "Provider sign-in returned an incomplete session.";
return true; return true;
}; };
@@ -163,6 +123,8 @@ onMounted(async () => {
registrationEnabled.value = !!flags.registration_enabled; registrationEnabled.value = !!flags.registration_enabled;
providerLoginEnabled.value = !!flags.provider_login_enabled; providerLoginEnabled.value = !!flags.provider_login_enabled;
await authStore.ensureInitialized();
if (authStore.isAuthenticated) { if (authStore.isAuthenticated) {
await router.replace("/"); await router.replace("/");
return; return;

View File

@@ -4,39 +4,6 @@ import { useSettingsStore } from "../stores/settingsStore";
import LoginPage from "../pages/Login.vue"; import LoginPage from "../pages/Login.vue";
import RegisterPage from "../pages/Register.vue"; import RegisterPage from "../pages/Register.vue";
const decodeBase64UrlUTF8 = (value) => {
const normalized = value.replace(/-/g, "+").replace(/_/g, "/");
const padding = normalized.length % 4;
const padded = padding === 0 ? normalized : `${normalized}${"=".repeat(4 - padding)}`;
const binary = atob(padded);
const bytes = Uint8Array.from(binary, (ch) => ch.charCodeAt(0));
return new TextDecoder().decode(bytes);
};
const restoreOAuthSessionFromQuery = (query, authStore) => {
// Merge router query with URLSearchParams for full coverage
const params = new URLSearchParams(window.location.search);
const accessToken = query.access_token || query.accessToken || query.token || params.get("access_token") || params.get("accessToken") || params.get("token");
if (!accessToken) {
return false;
}
try {
const plainUserJSON = query.user_json || params.get("user_json");
const encodedUser = query.user || params.get("user");
const user = plainUserJSON ? JSON.parse(plainUserJSON) : encodedUser ? JSON.parse(decodeBase64UrlUTF8(encodedUser)) : null;
if (!user) {
return false;
}
authStore.setSession({ access_token: accessToken, user });
return true;
} catch {
return false;
}
};
const routes = [ const routes = [
{ {
path: "/login", path: "/login",
@@ -81,25 +48,7 @@ router.beforeEach(async (to, from, next) => {
const authStore = useAuthStore(); const authStore = useAuthStore();
const settingsStore = useSettingsStore(); const settingsStore = useSettingsStore();
// Only attempt OAuth callback restoration if actual OAuth query params are present await authStore.ensureInitialized();
const params = new URLSearchParams(window.location.search);
const hasOAuthParams = to.query.access_token || to.query.accessToken || to.query.token || params.get("access_token") || params.get("accessToken") || params.get("token");
if (to.path === "/login") {
if (hasOAuthParams) {
const restored = restoreOAuthSessionFromQuery(to.query, authStore);
if (restored) {
next({ path: "/", replace: true });
return;
}
}
// Allow login page to be viewed regardless of auth state if no OAuth callback
if (!hasOAuthParams) {
next();
return;
}
}
if (to.path === "/register") { if (to.path === "/register") {
await settingsStore.loadFeatureFlags(); await settingsStore.loadFeatureFlags();

View File

@@ -3,23 +3,57 @@ import { useAuthStore } from "../stores/authStore";
const apiClient = axios.create({ const apiClient = axios.create({
baseURL: import.meta.env.VITE_API_BASE_URL || "http://localhost:8080", baseURL: import.meta.env.VITE_API_BASE_URL || "http://localhost:8080",
withCredentials: true,
}); });
apiClient.interceptors.request.use((config) => { let isRefreshing = false;
const authStore = useAuthStore(); let refreshSubscribers = [];
if (authStore.accessToken) {
config.headers.Authorization = `Bearer ${authStore.accessToken}`; function onRefreshed() {
} refreshSubscribers.forEach((cb) => cb());
return config; refreshSubscribers = [];
}); }
apiClient.interceptors.response.use( apiClient.interceptors.response.use(
(response) => response, (response) => response,
(error) => { async (error) => {
if (error.response?.status === 401) { const originalRequest = error.config;
if (error.response?.status === 401 && !originalRequest._retry) {
// Avoid retrying the refresh request itself
if (originalRequest.url?.includes("/auth/refresh") || originalRequest.url?.includes("/auth/login")) {
const authStore = useAuthStore(); const authStore = useAuthStore();
authStore.logout(); authStore.clearSession();
return Promise.reject(error);
} }
if (isRefreshing) {
// Queue the request until the ongoing refresh completes
return new Promise((resolve, reject) => {
refreshSubscribers.push(() => {
originalRequest._retry = true;
apiClient(originalRequest).then(resolve).catch(reject);
});
});
}
originalRequest._retry = true;
isRefreshing = true;
try {
await apiClient.post("/api/v1/auth/refresh");
onRefreshed();
return apiClient(originalRequest);
} catch {
refreshSubscribers = [];
const authStore = useAuthStore();
authStore.clearSession();
return Promise.reject(error);
} finally {
isRefreshing = false;
}
}
return Promise.reject(error); return Promise.reject(error);
}, },
); );

View File

@@ -3,10 +3,11 @@ import { ref, computed } from "vue";
import apiClient from "../services/apiClient"; import apiClient from "../services/apiClient";
export const useAuthStore = defineStore("auth", () => { export const useAuthStore = defineStore("auth", () => {
const storedUser = localStorage.getItem("user"); const user = ref(null);
const user = ref(storedUser ? JSON.parse(storedUser) : null); const initialized = ref(false);
const accessToken = ref(localStorage.getItem("accessToken")); let initPromise = null;
const isAuthenticated = computed(() => !!accessToken.value && !!user.value);
const isAuthenticated = computed(() => !!user.value);
const isAdmin = computed(() => hasPermission("*") || hasPermission("admin.access")); const isAdmin = computed(() => hasPermission("*") || hasPermission("admin.access"));
const normalizePermission = (permission) => (permission || "").trim().toLowerCase(); const normalizePermission = (permission) => (permission || "").trim().toLowerCase();
@@ -46,10 +47,36 @@ export const useAuthStore = defineStore("auth", () => {
}; };
const setSession = (responseData) => { const setSession = (responseData) => {
accessToken.value = responseData.access_token; user.value = responseData?.user || null;
user.value = responseData.user; initialized.value = true;
localStorage.setItem("accessToken", accessToken.value); };
localStorage.setItem("user", JSON.stringify(user.value));
const clearSession = () => {
user.value = null;
initialized.value = true;
};
const loadSession = async () => {
try {
const response = await apiClient.get("/api/v1/auth/me");
user.value = response.data?.user || null;
} catch {
user.value = null;
} finally {
initialized.value = true;
}
};
const ensureInitialized = async () => {
if (initialized.value) {
return;
}
if (!initPromise) {
initPromise = loadSession().finally(() => {
initPromise = null;
});
}
await initPromise;
}; };
const register = async (email, username, password, firstName = "", lastName = "") => { const register = async (email, username, password, firstName = "", lastName = "") => {
@@ -87,20 +114,20 @@ export const useAuthStore = defineStore("auth", () => {
}; };
const logout = () => { const logout = () => {
accessToken.value = null; apiClient.post("/api/v1/auth/logout").catch(() => {});
user.value = null; clearSession();
localStorage.removeItem("accessToken");
localStorage.removeItem("user");
}; };
return { return {
user, user,
accessToken, initialized,
isAuthenticated, isAuthenticated,
isAdmin, isAdmin,
hasPermission, hasPermission,
hasSpacePermission, hasSpacePermission,
setSession, setSession,
clearSession,
ensureInitialized,
register, register,
login, login,
logout, logout,