package services import ( "context" "encoding/base64" "encoding/json" "errors" "fmt" "io" "net/http" "regexp" "strings" "time" "github.com/noteapp/backend/internal/application/dto" "github.com/noteapp/backend/internal/domain/entities" "github.com/noteapp/backend/internal/domain/repositories" "github.com/noteapp/backend/internal/infrastructure/auth" "github.com/noteapp/backend/internal/infrastructure/security" "go.mongodb.org/mongo-driver/v2/bson" "golang.org/x/oauth2" ) // AuthService handles authentication operations type AuthService struct { userRepo repositories.UserRepository groupRepo repositories.GroupRepository providerRepo repositories.AuthProviderRepository linkRepo repositories.UserProviderLinkRepository recoveryRepo repositories.AccountRecoveryRepository featureFlagRepo repositories.FeatureFlagRepository permissionService *PermissionService jwtManager *auth.JWTManager passHasher *security.PasswordHasher encryptor *security.Encryptor } // NewAuthService creates a new auth service func NewAuthService( userRepo repositories.UserRepository, groupRepo repositories.GroupRepository, providerRepo repositories.AuthProviderRepository, linkRepo repositories.UserProviderLinkRepository, recoveryRepo repositories.AccountRecoveryRepository, featureFlagRepo repositories.FeatureFlagRepository, permissionService *PermissionService, jwtManager *auth.JWTManager, passHasher *security.PasswordHasher, encryptor *security.Encryptor, ) *AuthService { return &AuthService{ userRepo: userRepo, groupRepo: groupRepo, providerRepo: providerRepo, linkRepo: linkRepo, recoveryRepo: recoveryRepo, featureFlagRepo: featureFlagRepo, permissionService: permissionService, jwtManager: jwtManager, passHasher: passHasher, encryptor: encryptor, } } // Register registers a new user func (s *AuthService) Register(ctx context.Context, req *dto.RegisterRequest) (*dto.LoginResponse, error) { flags, err := s.GetFeatureFlags(ctx) if err != nil { return nil, err } if !flags.RegistrationEnabled { return nil, errors.New("registration is currently disabled") } req.Email = strings.ToLower(strings.TrimSpace(req.Email)) req.Username = strings.TrimSpace(req.Username) // Check if email already exists _, err = s.userRepo.GetUserByEmail(ctx, req.Email) if err == nil { return nil, errors.New("email already registered") } // Check if username already exists _, err = s.userRepo.GetUserByUsername(ctx, req.Username) if err == nil { return nil, errors.New("username already taken") } // Hash password hashedPassword, err := s.passHasher.HashPassword(req.Password) if err != nil { return nil, err } // Create user user := &entities.User{ Email: req.Email, Username: req.Username, PasswordHash: hashedPassword, FirstName: req.FirstName, LastName: req.LastName, IsActive: true, EmailVerified: false, // Should verify email in production } if err := s.userRepo.CreateUser(ctx, user); err != nil { return nil, err } if s.permissionService != nil { if err := s.permissionService.UpdateUserEffectivePermissions(ctx, user); err != nil { return nil, err } } // Generate tokens accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) if err != nil { return nil, err } refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex()) if err != nil { return nil, err } return &dto.LoginResponse{ AccessToken: accessToken, RefreshToken: refreshToken, User: dto.NewUserDTO(user), ExpiresIn: 3600, // 1 hour }, nil } // Login authenticates a user func (s *AuthService) Login(ctx context.Context, req *dto.LoginRequest) (*dto.LoginResponse, error) { req.Email = strings.ToLower(strings.TrimSpace(req.Email)) // Get user by email user, err := s.userRepo.GetUserByEmail(ctx, req.Email) if err != nil { return nil, errors.New("invalid credentials") } if !user.IsActive { return nil, errors.New("account is inactive") } // Verify password match, err := s.passHasher.VerifyPassword(req.Password, user.PasswordHash) if err != nil || !match { return nil, errors.New("invalid credentials") } // Update last login now := time.Now() user.LastLoginAt = &now if s.permissionService != nil { if err := s.permissionService.UpdateUserEffectivePermissions(ctx, user); err != nil { return nil, err } } if err := s.userRepo.UpdateUser(ctx, user); err != nil { // Log error but don't fail the login } // Generate tokens accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) if err != nil { return nil, err } refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex()) if err != nil { return nil, err } return &dto.LoginResponse{ AccessToken: accessToken, RefreshToken: refreshToken, User: dto.NewUserDTO(user), ExpiresIn: 3600, }, nil } // RefreshAccessToken refreshes an access token func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken string) (string, error) { claims, err := s.jwtManager.VerifyRefreshToken(refreshToken) if err != nil { return "", err } user, err := s.userRepo.GetUserByID(ctx, mustParseObjectID(claims.UserID)) if err != nil { return "", err } return s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) } // RequestPasswordReset initiates password reset flow func (s *AuthService) RequestPasswordReset(ctx context.Context, email string) error { user, err := s.userRepo.GetUserByEmail(ctx, email) if err != nil { // Don't reveal if email exists (security best practice) return nil } token, err := auth.GenerateRandomToken(32) if err != nil { return err } recovery := &entities.AccountRecovery{ UserID: user.ID, Token: token, Type: "password_reset", ExpiresAt: time.Now().Add(1 * time.Hour), } // Save recovery token // This would need AccountRecoveryRepository implementation _ = recovery // In production: send email with reset link containing token return nil } // mustParseObjectID parses a string to ObjectID, panics on error func mustParseObjectID(id string) bson.ObjectID { objID, _ := bson.ObjectIDFromHex(id) return objID } // ListProviders returns all active OAuth/OIDC providers. func (s *AuthService) ListProviders(ctx context.Context) ([]*dto.AuthProviderDTO, error) { flags, err := s.GetFeatureFlags(ctx) if err != nil { return nil, err } if !flags.ProviderLoginEnabled { return []*dto.AuthProviderDTO{}, nil } if s.providerRepo == nil { return []*dto.AuthProviderDTO{}, nil } providers, err := s.providerRepo.GetAllProviders(ctx) if err != nil { return nil, err } result := make([]*dto.AuthProviderDTO, 0, len(providers)) for _, provider := range providers { result = append(result, dto.NewAuthProviderDTO(provider)) } return result, nil } // GetFeatureFlags returns current app-wide feature flags. func (s *AuthService) GetFeatureFlags(ctx context.Context) (*dto.FeatureFlagsDTO, error) { if s.featureFlagRepo == nil { return dto.NewFeatureFlagsDTO(nil), nil } flags, err := s.featureFlagRepo.GetFeatureFlags(ctx) if err != nil { return nil, err } return dto.NewFeatureFlagsDTO(flags), nil } // CreateProvider stores a new OAuth/OIDC provider. func (s *AuthService) CreateProvider(ctx context.Context, req *dto.CreateAuthProviderRequest) (*dto.AuthProviderDTO, error) { if s.providerRepo == nil || s.encryptor == nil { return nil, errors.New("provider configuration unavailable") } providerType := strings.ToLower(strings.TrimSpace(req.Type)) if providerType != "oidc" && providerType != "oauth2" { return nil, errors.New("provider type must be oidc or oauth2") } name := strings.TrimSpace(req.Name) clientID := strings.TrimSpace(req.ClientID) clientSecret := strings.TrimSpace(req.ClientSecret) authorizationURL := strings.TrimSpace(req.AuthorizationURL) tokenURL := strings.TrimSpace(req.TokenURL) if name == "" || clientID == "" || clientSecret == "" || authorizationURL == "" || tokenURL == "" { return nil, errors.New("missing required provider fields") } encryptedSecret, err := s.encryptor.Encrypt(clientSecret) if err != nil { return nil, err } provider := &entities.AuthProvider{ Name: name, Type: providerType, ClientID: clientID, ClientSecret: encryptedSecret, AuthorizationURL: authorizationURL, TokenURL: tokenURL, UserInfoURL: strings.TrimSpace(req.UserInfoURL), Scopes: normalizeScopes(req.Scopes, providerType), IDTokenClaim: strings.TrimSpace(req.IDTokenClaim), IsActive: req.IsActive, } if err := s.providerRepo.CreateProvider(ctx, provider); err != nil { return nil, err } return dto.NewAuthProviderDTO(provider), nil } // UpdateProvider updates an existing OAuth/OIDC provider. // If ClientSecret is empty, the existing encrypted secret is preserved. func (s *AuthService) UpdateProvider(ctx context.Context, providerID bson.ObjectID, req *dto.UpdateAuthProviderRequest) (*dto.AuthProviderDTO, error) { if s.providerRepo == nil || s.encryptor == nil { return nil, errors.New("provider configuration unavailable") } existing, err := s.providerRepo.GetProviderByID(ctx, providerID) if err != nil { return nil, err } providerType := strings.ToLower(strings.TrimSpace(req.Type)) if providerType != "oidc" && providerType != "oauth2" { return nil, errors.New("provider type must be oidc or oauth2") } name := strings.TrimSpace(req.Name) clientID := strings.TrimSpace(req.ClientID) authorizationURL := strings.TrimSpace(req.AuthorizationURL) tokenURL := strings.TrimSpace(req.TokenURL) if name == "" || clientID == "" || authorizationURL == "" || tokenURL == "" { return nil, errors.New("missing required provider fields") } existing.Name = name existing.Type = providerType existing.ClientID = clientID existing.AuthorizationURL = authorizationURL existing.TokenURL = tokenURL existing.UserInfoURL = strings.TrimSpace(req.UserInfoURL) existing.Scopes = normalizeScopes(req.Scopes, providerType) existing.IDTokenClaim = strings.TrimSpace(req.IDTokenClaim) existing.IsActive = req.IsActive clientSecret := strings.TrimSpace(req.ClientSecret) if clientSecret != "" { encrypted, err := s.encryptor.Encrypt(clientSecret) if err != nil { return nil, err } existing.ClientSecret = encrypted } if err := s.providerRepo.UpdateProvider(ctx, existing); err != nil { return nil, err } return dto.NewAuthProviderDTO(existing), nil } // BuildProviderAuthorizationURL constructs a provider authorization URL. func (s *AuthService) BuildProviderAuthorizationURL(ctx context.Context, providerID bson.ObjectID, redirectURI, state string) (string, error) { flags, err := s.GetFeatureFlags(ctx) if err != nil { return "", err } if !flags.ProviderLoginEnabled { return "", errors.New("provider login is currently disabled") } provider, secret, err := s.getProviderConfig(ctx, providerID) if err != nil { return "", err } config := oauth2.Config{ ClientID: provider.ClientID, ClientSecret: secret, RedirectURL: redirectURI, Scopes: normalizeScopes(provider.Scopes, provider.Type), Endpoint: oauth2.Endpoint{ AuthURL: provider.AuthorizationURL, TokenURL: provider.TokenURL, }, } return config.AuthCodeURL(state, oauth2.AccessTypeOffline), nil } // CompleteProviderLogin exchanges an auth code and creates a user session. func (s *AuthService) CompleteProviderLogin(ctx context.Context, providerID bson.ObjectID, code, redirectURI string) (*dto.LoginResponse, error) { if s.providerRepo == nil || s.linkRepo == nil { return nil, errors.New("provider login unavailable") } flags, err := s.GetFeatureFlags(ctx) if err != nil { return nil, err } if !flags.ProviderLoginEnabled { return nil, errors.New("provider login is currently disabled") } provider, secret, err := s.getProviderConfig(ctx, providerID) if err != nil { return nil, err } config := oauth2.Config{ ClientID: provider.ClientID, ClientSecret: secret, RedirectURL: redirectURI, Scopes: normalizeScopes(provider.Scopes, provider.Type), Endpoint: oauth2.Endpoint{ AuthURL: provider.AuthorizationURL, TokenURL: provider.TokenURL, }, } token, err := config.Exchange(ctx, code) if err != nil { return nil, err } profile, err := s.fetchProviderProfile(ctx, provider, token.AccessToken, token.Extra(provider.IDTokenClaim)) if err != nil { return nil, err } user, err := s.findOrCreateOAuthUser(ctx, provider, profile) if err != nil { return nil, err } accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username) if err != nil { return nil, err } refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex()) if err != nil { return nil, err } return &dto.LoginResponse{AccessToken: accessToken, RefreshToken: refreshToken, User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil } type providerProfile struct { ProviderUserID string Email string Username string FirstName string LastName string } func (s *AuthService) getProviderConfig(ctx context.Context, providerID bson.ObjectID) (*entities.AuthProvider, string, error) { provider, err := s.providerRepo.GetProviderByID(ctx, providerID) if err != nil { return nil, "", err } if !provider.IsActive { return nil, "", errors.New("provider is inactive") } secret, err := s.encryptor.Decrypt(provider.ClientSecret) if err != nil { return nil, "", err } return provider, secret, nil } func (s *AuthService) fetchProviderProfile(ctx context.Context, provider *entities.AuthProvider, accessToken string, rawIDToken any) (*providerProfile, error) { payload := map[string]any{} if provider.UserInfoURL != "" { req, err := http.NewRequestWithContext(ctx, http.MethodGet, provider.UserInfoURL, nil) if err != nil { return nil, err } req.Header.Set("Authorization", "Bearer "+accessToken) resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode >= 400 { body, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("provider userinfo request failed: %s", string(body)) } if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { return nil, err } } else if idToken, ok := rawIDToken.(string); ok && idToken != "" { payload = decodeJWTWithoutVerify(idToken) } else { return nil, errors.New("provider must define userinfo_url or return id_token") } profile := &providerProfile{ ProviderUserID: firstNonEmpty(asString(payload["sub"]), asString(payload["id"]), asString(payload["user_id"])), Email: strings.ToLower(strings.TrimSpace(firstNonEmpty(asString(payload["email"]), asString(payload["upn"])))), Username: firstNonEmpty(asString(payload["preferred_username"]), asString(payload["login"]), asString(payload["name"])), FirstName: firstNonEmpty(asString(payload["given_name"]), asString(payload["first_name"])), LastName: firstNonEmpty(asString(payload["family_name"]), asString(payload["last_name"])), } if profile.ProviderUserID == "" { return nil, errors.New("provider user info missing subject identifier") } if profile.Email == "" { profile.Email = fmt.Sprintf("%s@%s.oauth.local", sanitizeUsername(profile.ProviderUserID), sanitizeUsername(provider.Name)) } if profile.Username == "" { profile.Username = strings.Split(profile.Email, "@")[0] } profile.Username = sanitizeUsername(profile.Username) return profile, nil } func (s *AuthService) findOrCreateOAuthUser(ctx context.Context, provider *entities.AuthProvider, profile *providerProfile) (*entities.User, error) { if link, err := s.linkRepo.GetLinkByProviderUserID(ctx, provider.ID, profile.ProviderUserID); err == nil { return s.userRepo.GetUserByID(ctx, link.UserID) } user, err := s.userRepo.GetUserByEmail(ctx, profile.Email) if err != nil { username, err := s.generateUniqueUsername(ctx, profile.Username) if err != nil { return nil, err } user = &entities.User{Email: profile.Email, Username: username, PasswordHash: "", FirstName: profile.FirstName, LastName: profile.LastName, IsActive: true, EmailVerified: true} if err := s.userRepo.CreateUser(ctx, user); err != nil { return nil, err } } if _, err := s.linkRepo.GetLink(ctx, user.ID, provider.ID); err != nil { if err := s.linkRepo.CreateLink(ctx, &entities.UserProviderLink{UserID: user.ID, ProviderID: provider.ID, ProviderUserID: profile.ProviderUserID, Email: profile.Email}); err != nil { return nil, err } } return user, nil } func (s *AuthService) generateUniqueUsername(ctx context.Context, base string) (string, error) { base = sanitizeUsername(base) candidates := []string{base} for i := 0; i < 5; i++ { token, err := auth.GenerateRandomToken(2) if err != nil { return "", err } candidates = append(candidates, fmt.Sprintf("%s-%s", base, token[:4])) } for _, candidate := range candidates { if _, err := s.userRepo.GetUserByUsername(ctx, candidate); err != nil { return candidate, nil } } return fmt.Sprintf("%s-%d", base, time.Now().Unix()), nil } func normalizeScopes(scopes []string, providerType string) []string { if len(scopes) == 0 { if providerType == "oidc" { return []string{"openid", "profile", "email"} } return []string{"profile", "email"} } result := make([]string, 0, len(scopes)) for _, scope := range scopes { scope = strings.TrimSpace(scope) if scope != "" { result = append(result, scope) } } return result } func decodeJWTWithoutVerify(token string) map[string]any { parts := strings.Split(token, ".") if len(parts) < 2 { return map[string]any{} } decoded, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return map[string]any{} } claims := map[string]any{} if err := json.Unmarshal(decoded, &claims); err != nil { return map[string]any{} } return claims } func asString(value any) string { if str, ok := value.(string); ok { return strings.TrimSpace(str) } return "" } func firstNonEmpty(values ...string) string { for _, value := range values { if value != "" { return value } } return "" } func sanitizeUsername(value string) string { cleaned := regexp.MustCompile(`[^a-zA-Z0-9_-]+`).ReplaceAllString(strings.ToLower(strings.TrimSpace(value)), "-") cleaned = strings.Trim(cleaned, "-") if cleaned == "" { return "user" } return cleaned }