diff --git a/backend/.env.example b/backend/.env.example index 37726b3..2b74a90 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -26,3 +26,9 @@ CORS_ALLOWED_ORIGINS=http://localhost:5173,http://localhost:3000 # Rate Limiting RATE_LIMIT_REQUESTS=50 RATE_LIMIT_WINDOW=1s + +# Redis Sessions +REDIS_ADDR=localhost:6379 +REDIS_PASSWORD= +REDIS_DB=0 +SESSION_TTL_HOURS=168 diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 79030e1..fb7519f 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "os" + "strconv" "strings" "time" @@ -19,6 +20,7 @@ import ( "github.com/noteapp/backend/internal/infrastructure/security" "github.com/noteapp/backend/internal/interfaces/handlers" "github.com/noteapp/backend/internal/interfaces/middleware" + "github.com/redis/go-redis/v9" "go.mongodb.org/mongo-driver/v2/bson" ) @@ -47,6 +49,30 @@ func main() { port = "8080" } + redisAddr := os.Getenv("REDIS_ADDR") + if redisAddr == "" { + redisAddr = "localhost:6379" + } + + redisPassword := os.Getenv("REDIS_PASSWORD") + redisDB := 0 + if redisDBText := os.Getenv("REDIS_DB"); redisDBText != "" { + parsedDB, err := strconv.Atoi(redisDBText) + if err != nil { + log.Fatalf("invalid REDIS_DB value: %v", err) + } + redisDB = parsedDB + } + + sessionTTL := 7 * 24 * time.Hour + if sessionTTLText := os.Getenv("SESSION_TTL_HOURS"); sessionTTLText != "" { + hours, err := strconv.Atoi(sessionTTLText) + if err != nil || hours <= 0 { + log.Fatalf("invalid SESSION_TTL_HOURS value: %q", sessionTTLText) + } + sessionTTL = time.Duration(hours) * time.Hour + } + // Connect to database ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -57,6 +83,19 @@ func main() { } defer db.Close(context.Background()) + redisClient := redis.NewClient(&redis.Options{ + Addr: redisAddr, + Password: redisPassword, + DB: redisDB, + }) + + if err := redisClient.Ping(context.Background()).Err(); err != nil { + log.Fatalf("failed to connect to redis: %v", err) + } + defer func() { + _ = redisClient.Close() + }() + // Initialize security components passwordHasher := security.NewPasswordHasher() encryptor, err := security.NewEncryptor(encryptionKey) @@ -66,6 +105,7 @@ func main() { // Initialize JWT manager jwtManager := auth.NewJWTManager(jwtSecret, "noteapp", 1*time.Hour) + sessionManager := auth.NewSessionManager(redisClient, sessionTTL) // Initialize services permissionService := services.NewPermissionService( @@ -143,7 +183,7 @@ func main() { } // Initialize handlers - authHandler := handlers.NewAuthHandler(authService) + authHandler := handlers.NewAuthHandler(authService, sessionManager) spaceHandler := handlers.NewSpaceHandler(spaceService) noteHandler := handlers.NewNoteHandler(noteService) categoryHandler := handlers.NewCategoryHandler(categoryService) @@ -160,7 +200,7 @@ func main() { }) // Middleware - authMiddleware := middleware.NewAuthMiddleware(jwtManager) + authMiddleware := middleware.NewAuthMiddleware(jwtManager, sessionManager) router.Use(middleware.LoggingMiddleware) router.Use(middleware.CORSMiddleware) router.Use(middleware.SecurityHeaders) @@ -187,6 +227,7 @@ func main() { // Protected endpoints api := router.PathPrefix("/api/v1").Subrouter() api.Use(authMiddleware.Middleware) + api.HandleFunc("/auth/me", authHandler.Me).Methods("GET") // Space endpoints api.HandleFunc("/spaces", spaceHandler.GetUserSpaces).Methods("GET") diff --git a/backend/go.mod b/backend/go.mod index bd37121..b8afd3d 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -3,18 +3,20 @@ module github.com/noteapp/backend go 1.25.0 require ( + github.com/aws/aws-sdk-go-v2 v1.41.4 + github.com/aws/aws-sdk-go-v2/credentials v1.19.12 + github.com/aws/aws-sdk-go-v2/service/s3 v1.97.2 github.com/golang-jwt/jwt/v5 v5.2.0 github.com/gorilla/mux v1.8.1 github.com/joho/godotenv v1.5.1 + github.com/redis/go-redis/v9 v9.18.0 go.mongodb.org/mongo-driver/v2 v2.5.0 golang.org/x/crypto v0.49.0 golang.org/x/oauth2 v0.30.0 ) require ( - github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 // indirect @@ -22,13 +24,15 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 // indirect - github.com/aws/aws-sdk-go-v2/service/s3 v1.97.2 // indirect github.com/aws/smithy-go v1.24.2 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/klauspost/compress v1.17.6 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.2.0 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 149adc2..a9412a1 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -22,8 +22,16 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.97.2 h1:MRNiP6nqa20aEl8fQ6PJpEq11b2d4 github.com/aws/aws-sdk-go-v2/service/s3 v1.97.2/go.mod h1:FrNA56srbsr3WShiaelyWYEo70x80mXnVZ17ZZfbeqg= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= @@ -34,6 +42,14 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= +github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= @@ -43,8 +59,12 @@ github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gi github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= diff --git a/backend/internal/application/services/auth_service.go b/backend/internal/application/services/auth_service.go index 59227a2..1195849 100644 --- a/backend/internal/application/services/auth_service.go +++ b/backend/internal/application/services/auth_service.go @@ -114,22 +114,9 @@ func (s *AuthService) Register(ctx context.Context, req *dto.RegisterRequest) (* } } - // Generate tokens - accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) - if err != nil { - return nil, err - } - - refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex()) - if err != nil { - return nil, err - } - return &dto.LoginResponse{ - AccessToken: accessToken, - RefreshToken: refreshToken, - User: dto.NewUserDTO(user), - ExpiresIn: 3600, // 1 hour + User: dto.NewUserDTO(user), + ExpiresIn: 3600, // 1 hour }, nil } @@ -165,27 +152,18 @@ func (s *AuthService) Login(ctx context.Context, req *dto.LoginRequest) (*dto.Lo // Log error but don't fail the login } - // Generate tokens - accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) - if err != nil { - return nil, err - } - - refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex()) - if err != nil { - return nil, err - } - return &dto.LoginResponse{ - AccessToken: accessToken, - RefreshToken: refreshToken, - User: dto.NewUserDTO(user), - ExpiresIn: 3600, + User: dto.NewUserDTO(user), + ExpiresIn: 3600, }, nil } // RefreshAccessToken refreshes an access token func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken string) (string, error) { + if s.jwtManager == nil { + return "", errors.New("jwt refresh is unavailable") + } + claims, err := s.jwtManager.VerifyRefreshToken(refreshToken) if err != nil { return "", err @@ -199,6 +177,27 @@ func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken strin return s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) } +// GetUserProfile returns profile DTO for the provided user ID. +func (s *AuthService) GetUserProfile(ctx context.Context, userID string) (*dto.UserDTO, error) { + objID, err := bson.ObjectIDFromHex(strings.TrimSpace(userID)) + if err != nil { + return nil, errors.New("invalid user id") + } + + user, err := s.userRepo.GetUserByID(ctx, objID) + if err != nil { + return nil, err + } + + if s.permissionService != nil { + if err := s.permissionService.UpdateUserEffectivePermissions(ctx, user); err != nil { + return nil, err + } + } + + return dto.NewUserDTO(user), nil +} + // RequestPasswordReset initiates password reset flow func (s *AuthService) RequestPasswordReset(ctx context.Context, email string) error { user, err := s.userRepo.GetUserByEmail(ctx, email) @@ -444,17 +443,7 @@ func (s *AuthService) CompleteProviderLogin(ctx context.Context, providerID bson return nil, err } - accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) - if err != nil { - return nil, err - } - - refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex()) - if err != nil { - return nil, err - } - - return &dto.LoginResponse{AccessToken: accessToken, RefreshToken: refreshToken, User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil + return &dto.LoginResponse{User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil } type providerProfile struct { diff --git a/backend/internal/infrastructure/auth/session.go b/backend/internal/infrastructure/auth/session.go new file mode 100644 index 0000000..c0675fa --- /dev/null +++ b/backend/internal/infrastructure/auth/session.go @@ -0,0 +1,114 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/redis/go-redis/v9" +) + +// SessionData stores authenticated identity data in Redis. +type SessionData struct { + UserID string `json:"user_id"` + Email string `json:"email"` + Username string `json:"username"` +} + +// SessionManager handles Redis-backed session lifecycle operations. +type SessionManager struct { + redis *redis.Client + ttl time.Duration + prefix string +} + +func NewSessionManager(redisClient *redis.Client, ttl time.Duration) *SessionManager { + if ttl <= 0 { + ttl = 7 * 24 * time.Hour + } + + return &SessionManager{ + redis: redisClient, + ttl: ttl, + prefix: "session:", + } +} + +func (m *SessionManager) TTL() time.Duration { + return m.ttl +} + +func (m *SessionManager) CreateSession(ctx context.Context, data *SessionData) (string, error) { + if data == nil { + return "", errors.New("session data is required") + } + + sessionID, err := GenerateRandomToken(32) + if err != nil { + return "", err + } + + payload, err := json.Marshal(data) + if err != nil { + return "", err + } + + if err := m.redis.Set(ctx, m.key(sessionID), payload, m.ttl).Err(); err != nil { + return "", err + } + + return sessionID, nil +} + +func (m *SessionManager) GetSession(ctx context.Context, sessionID string) (*SessionData, error) { + if sessionID == "" { + return nil, errors.New("session id is required") + } + + payload, err := m.redis.Get(ctx, m.key(sessionID)).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, errors.New("session not found") + } + return nil, err + } + + var data SessionData + if err := json.Unmarshal([]byte(payload), &data); err != nil { + return nil, err + } + + return &data, nil +} + +func (m *SessionManager) RefreshSession(ctx context.Context, sessionID string) error { + if sessionID == "" { + return errors.New("session id is required") + } + + if err := m.redis.Expire(ctx, m.key(sessionID), m.ttl).Err(); err != nil { + if errors.Is(err, redis.Nil) { + return errors.New("session not found") + } + return err + } + + return nil +} + +func (m *SessionManager) DeleteSession(ctx context.Context, sessionID string) error { + if sessionID == "" { + return nil + } + + if err := m.redis.Del(ctx, m.key(sessionID)).Err(); err != nil { + return err + } + + return nil +} + +func (m *SessionManager) key(sessionID string) string { + return m.prefix + sessionID +} diff --git a/backend/internal/interfaces/handlers/auth_handler.go b/backend/internal/interfaces/handlers/auth_handler.go index ee10922..f7747c0 100644 --- a/backend/internal/interfaces/handlers/auth_handler.go +++ b/backend/internal/interfaces/handlers/auth_handler.go @@ -1,7 +1,6 @@ package handlers import ( - "encoding/base64" "encoding/json" "net/http" "net/url" @@ -17,16 +16,20 @@ import ( // AuthHandler handles authentication endpoints type AuthHandler struct { - authService *services.AuthService + authService *services.AuthService + sessionManager *auth.SessionManager } // NewAuthHandler creates a new auth handler -func NewAuthHandler(authService *services.AuthService) *AuthHandler { +func NewAuthHandler(authService *services.AuthService, sessionManager *auth.SessionManager) *AuthHandler { return &AuthHandler{ - authService: authService, + authService: authService, + sessionManager: sessionManager, } } +const sessionCookieName = "session_id" + // Register handles user registration func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -56,6 +59,11 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { return } + if err := h.setSessionCookie(w, r, response.User); err != nil { + http.Error(w, "Failed to create session", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) } @@ -79,16 +87,10 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { return } - // Set secure HTTP-only cookie for refresh token - http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", - Value: response.RefreshToken, - Path: "/", - MaxAge: 7 * 24 * 60 * 60, // 7 days - HttpOnly: true, - Secure: isSecureRequest(r), - SameSite: http.SameSiteLaxMode, - }) + if err := h.setSessionCookie(w, r, response.User); err != nil { + http.Error(w, "Failed to create session", http.StatusInternalServerError) + return + } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(response) @@ -96,15 +98,12 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { // Logout handles user logout func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) { - // Clear refresh token cookie - http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", - Value: "", - Path: "/", - MaxAge: -1, - HttpOnly: true, - Secure: isSecureRequest(r), - }) + sessionCookie, err := r.Cookie(sessionCookieName) + if err == nil { + _ = h.sessionManager.DeleteSession(r.Context(), sessionCookie.Value) + } + + h.clearSessionCookie(w, r) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"message": "Logged out successfully"}) @@ -215,7 +214,7 @@ func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Reque response, err := h.authService.CompleteProviderLogin(r.Context(), providerID, r.URL.Query().Get("code"), buildBackendURL(r, "/api/v1/auth/providers/"+providerID.Hex()+"/callback")) if err != nil { - http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error(), "", nil), http.StatusFound) + http.Redirect(w, r, buildFrontendLoginURL("oauth_error", err.Error()), http.StatusFound) return } @@ -229,17 +228,12 @@ func (h *AuthHandler) CompleteProviderLogin(w http.ResponseWriter, r *http.Reque SameSite: http.SameSiteLaxMode, }) - http.SetCookie(w, &http.Cookie{ - Name: "refresh_token", - Value: response.RefreshToken, - Path: "/", - MaxAge: 7 * 24 * 60 * 60, - HttpOnly: true, - Secure: isSecureRequest(r), - SameSite: http.SameSiteLaxMode, - }) + if err := h.setSessionCookie(w, r, response.User); err != nil { + http.Redirect(w, r, buildFrontendLoginURL("oauth_error", "Failed to create session"), http.StatusFound) + return + } - http.Redirect(w, r, buildFrontendLoginURL("oauth_success", "", response.AccessToken, response.User), http.StatusFound) + http.Redirect(w, r, buildFrontendLoginURL("oauth_success", ""), http.StatusFound) } // RefreshToken handles token refresh @@ -249,23 +243,57 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { return } - // Get refresh token from cookie - cookie, err := r.Cookie("refresh_token") + cookie, err := r.Cookie(sessionCookieName) if err != nil { - http.Error(w, "Refresh token not found", http.StatusUnauthorized) + http.Error(w, "Session not found", http.StatusUnauthorized) return } - accessToken, err := h.authService.RefreshAccessToken(r.Context(), cookie.Value) + sessionData, err := h.sessionManager.GetSession(r.Context(), cookie.Value) if err != nil { - http.Error(w, "Invalid refresh token", http.StatusUnauthorized) + http.Error(w, "Invalid session", http.StatusUnauthorized) return } + if err := h.sessionManager.RefreshSession(r.Context(), cookie.Value); err == nil { + http.SetCookie(w, h.newSessionCookie(r, cookie.Value)) + } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ - "access_token": accessToken, - "expires_in": 3600, + "user": sessionData, + "expires_in": int(h.sessionManager.TTL().Seconds()), + }) +} + +// Me returns the currently authenticated user profile. +func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) { + sessionCookie, err := r.Cookie(sessionCookieName) + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + sessionData, err := h.sessionManager.GetSession(r.Context(), sessionCookie.Value) + if err != nil { + h.clearSessionCookie(w, r) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + user, err := h.authService.GetUserProfile(r.Context(), sessionData.UserID) + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + if err := h.sessionManager.RefreshSession(r.Context(), sessionCookie.Value); err == nil { + http.SetCookie(w, h.newSessionCookie(r, sessionCookie.Value)) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "user": user, + "expires_in": int(h.sessionManager.TTL().Seconds()), }) } @@ -292,7 +320,7 @@ func buildBackendURL(r *http.Request, path string) string { return scheme + "://" + r.Host + path } -func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDTO) string { +func buildFrontendLoginURL(status, message string) string { frontendURL := os.Getenv("FRONTEND_URL") if frontendURL == "" { frontendURL = "http://localhost:5173" @@ -310,14 +338,48 @@ func buildFrontendLoginURL(status, message, accessToken string, user *dto.UserDT if message != "" { query.Set("message", message) } - if accessToken != "" { - query.Set("access_token", accessToken) - } - if user != nil { - payload, _ := json.Marshal(user) - query.Set("user_json", string(payload)) - query.Set("user", base64.RawURLEncoding.EncodeToString(payload)) - } parsed.RawQuery = query.Encode() return parsed.String() } + +func (h *AuthHandler) setSessionCookie(w http.ResponseWriter, r *http.Request, user *dto.UserDTO) error { + if user == nil { + return nil + } + + sessionID, err := h.sessionManager.CreateSession(r.Context(), &auth.SessionData{ + UserID: user.ID, + Email: user.Email, + Username: user.Username, + }) + if err != nil { + return err + } + + http.SetCookie(w, h.newSessionCookie(r, sessionID)) + return nil +} + +func (h *AuthHandler) newSessionCookie(r *http.Request, sessionID string) *http.Cookie { + return &http.Cookie{ + Name: sessionCookieName, + Value: sessionID, + Path: "/", + MaxAge: int(h.sessionManager.TTL().Seconds()), + HttpOnly: true, + Secure: isSecureRequest(r), + SameSite: http.SameSiteLaxMode, + } +} + +func (h *AuthHandler) clearSessionCookie(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + Secure: isSecureRequest(r), + SameSite: http.SameSiteLaxMode, + }) +} diff --git a/backend/internal/interfaces/middleware/auth.go b/backend/internal/interfaces/middleware/auth.go index 8cd9b77..dc06724 100644 --- a/backend/internal/interfaces/middleware/auth.go +++ b/backend/internal/interfaces/middleware/auth.go @@ -20,13 +20,15 @@ const ( // AuthMiddleware verifies JWT tokens type AuthMiddleware struct { - jwtManager *auth.JWTManager + jwtManager *auth.JWTManager + sessionManager *auth.SessionManager } // NewAuthMiddleware creates a new auth middleware -func NewAuthMiddleware(jwtManager *auth.JWTManager) *AuthMiddleware { +func NewAuthMiddleware(jwtManager *auth.JWTManager, sessionManager *auth.SessionManager) *AuthMiddleware { return &AuthMiddleware{ - jwtManager: jwtManager, + jwtManager: jwtManager, + sessionManager: sessionManager, } } @@ -41,16 +43,23 @@ func (m *AuthMiddleware) Middleware(next http.Handler) http.Handler { return } - // Extract token from Authorization header. - // For GET /files/object, also accept ?token= so markdown images render in-browser. - authHeader := r.Header.Get("Authorization") - if authHeader == "" && r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/files/object") { - if tok := r.URL.Query().Get("token"); tok != "" { - authHeader = "Bearer " + tok + if sessionCookie, err := r.Cookie("session_id"); err == nil && sessionCookie.Value != "" { + sessionData, sessionErr := m.sessionManager.GetSession(r.Context(), sessionCookie.Value) + if sessionErr == nil { + _ = m.sessionManager.RefreshSession(r.Context(), sessionCookie.Value) + + ctx := context.WithValue(r.Context(), UserIDKey, sessionData.UserID) + ctx = context.WithValue(ctx, EmailKey, sessionData.Email) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + return } } + + // Fall back to Authorization header for backwards compatibility. + authHeader := r.Header.Get("Authorization") if authHeader == "" { - http.Error(w, "Missing authorization header", http.StatusUnauthorized) + http.Error(w, "Unauthorized", http.StatusUnauthorized) return } diff --git a/backend/internal/interfaces/middleware/security.go b/backend/internal/interfaces/middleware/security.go index 8ef9e66..6fe940d 100644 --- a/backend/internal/interfaces/middleware/security.go +++ b/backend/internal/interfaces/middleware/security.go @@ -79,6 +79,7 @@ func CORSMiddleware(next http.Handler) http.Handler { } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") + w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Max-Age", "600") if r.Method == http.MethodOptions { diff --git a/devops/docker/nginx.conf b/devops/docker/nginx.conf index 78e2b85..590d0a6 100644 --- a/devops/docker/nginx.conf +++ b/devops/docker/nginx.conf @@ -44,20 +44,6 @@ http { listen 80; server_name localhost; - # API routes - location /api/ { - limit_req zone=api_limit burst=20 nodelay; - - proxy_pass http://notely; - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - proxy_connect_timeout 60s; - proxy_send_timeout 60s; - proxy_read_timeout 60s; - } - # Health check location /health { proxy_pass http://notely; diff --git a/docker-compose.yml b/docker-compose.yml index 071ee09..eee4ba8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,17 @@ -version: "3.8" - services: + redis: + image: redis:8-alpine + container_name: notely-redis + ports: + - "6379:6379" + networks: + - notely-network + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + mongodb: image: mongo:8.0 container_name: notely-mongodb @@ -39,9 +50,15 @@ services: DEFAULT_ADMIN_EMAIL: ${DEFAULT_ADMIN_EMAIL} DEFAULT_ADMIN_USERNAME: ${DEFAULT_ADMIN_USERNAME} DEFAULT_ADMIN_PASSWORD: ${DEFAULT_ADMIN_PASSWORD} + REDIS_ADDR: ${REDIS_ADDR} + REDIS_PASSWORD: ${REDIS_PASSWORD} + REDIS_DB: ${REDIS_DB} + SESSION_TTL_HOURS: ${SESSION_TTL_HOURS} depends_on: mongodb: condition: service_healthy + redis: + condition: service_healthy networks: - notely-network diff --git a/frontend/src/components/NoteEditor.vue b/frontend/src/components/NoteEditor.vue index 009b4a8..12b8efa 100644 --- a/frontend/src/components/NoteEditor.vue +++ b/frontend/src/components/NoteEditor.vue @@ -2,7 +2,6 @@
Deleting this note is permanent and cannot be undone.
+ +