first commit

This commit is contained in:
domrichardson
2026-03-24 16:03:04 +00:00
commit df40cc57e1
80 changed files with 16766 additions and 0 deletions

View File

@@ -0,0 +1,592 @@
package services
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/noteapp/backend/internal/application/dto"
"github.com/noteapp/backend/internal/domain/entities"
"github.com/noteapp/backend/internal/domain/repositories"
"github.com/noteapp/backend/internal/infrastructure/auth"
"github.com/noteapp/backend/internal/infrastructure/security"
"go.mongodb.org/mongo-driver/v2/bson"
"golang.org/x/oauth2"
)
// AuthService handles authentication operations
type AuthService struct {
userRepo repositories.UserRepository
groupRepo repositories.GroupRepository
providerRepo repositories.AuthProviderRepository
linkRepo repositories.UserProviderLinkRepository
recoveryRepo repositories.AccountRecoveryRepository
featureFlagRepo repositories.FeatureFlagRepository
permissionService *PermissionService
jwtManager *auth.JWTManager
passHasher *security.PasswordHasher
encryptor *security.Encryptor
}
// NewAuthService creates a new auth service
func NewAuthService(
userRepo repositories.UserRepository,
groupRepo repositories.GroupRepository,
providerRepo repositories.AuthProviderRepository,
linkRepo repositories.UserProviderLinkRepository,
recoveryRepo repositories.AccountRecoveryRepository,
featureFlagRepo repositories.FeatureFlagRepository,
permissionService *PermissionService,
jwtManager *auth.JWTManager,
passHasher *security.PasswordHasher,
encryptor *security.Encryptor,
) *AuthService {
return &AuthService{
userRepo: userRepo,
groupRepo: groupRepo,
providerRepo: providerRepo,
linkRepo: linkRepo,
recoveryRepo: recoveryRepo,
featureFlagRepo: featureFlagRepo,
permissionService: permissionService,
jwtManager: jwtManager,
passHasher: passHasher,
encryptor: encryptor,
}
}
// Register registers a new user
func (s *AuthService) Register(ctx context.Context, req *dto.RegisterRequest) (*dto.LoginResponse, error) {
flags, err := s.GetFeatureFlags(ctx)
if err != nil {
return nil, err
}
if !flags.RegistrationEnabled {
return nil, errors.New("registration is currently disabled")
}
req.Email = strings.ToLower(strings.TrimSpace(req.Email))
req.Username = strings.TrimSpace(req.Username)
// Check if email already exists
_, err = s.userRepo.GetUserByEmail(ctx, req.Email)
if err == nil {
return nil, errors.New("email already registered")
}
// Check if username already exists
_, err = s.userRepo.GetUserByUsername(ctx, req.Username)
if err == nil {
return nil, errors.New("username already taken")
}
// Hash password
hashedPassword, err := s.passHasher.HashPassword(req.Password)
if err != nil {
return nil, err
}
// Create user
user := &entities.User{
Email: req.Email,
Username: req.Username,
PasswordHash: hashedPassword,
FirstName: req.FirstName,
LastName: req.LastName,
IsActive: true,
EmailVerified: false, // Should verify email in production
}
if err := s.userRepo.CreateUser(ctx, user); err != nil {
return nil, err
}
if s.permissionService != nil {
if err := s.permissionService.UpdateUserEffectivePermissions(ctx, user); err != nil {
return nil, err
}
}
// Generate tokens
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
if err != nil {
return nil, err
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
if err != nil {
return nil, err
}
return &dto.LoginResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
User: dto.NewUserDTO(user),
ExpiresIn: 3600, // 1 hour
}, nil
}
// Login authenticates a user
func (s *AuthService) Login(ctx context.Context, req *dto.LoginRequest) (*dto.LoginResponse, error) {
req.Email = strings.ToLower(strings.TrimSpace(req.Email))
// Get user by email
user, err := s.userRepo.GetUserByEmail(ctx, req.Email)
if err != nil {
return nil, errors.New("invalid credentials")
}
if !user.IsActive {
return nil, errors.New("account is inactive")
}
// Verify password
match, err := s.passHasher.VerifyPassword(req.Password, user.PasswordHash)
if err != nil || !match {
return nil, errors.New("invalid credentials")
}
// Update last login
now := time.Now()
user.LastLoginAt = &now
if s.permissionService != nil {
if err := s.permissionService.UpdateUserEffectivePermissions(ctx, user); err != nil {
return nil, err
}
}
if err := s.userRepo.UpdateUser(ctx, user); err != nil {
// Log error but don't fail the login
}
// Generate tokens
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
if err != nil {
return nil, err
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
if err != nil {
return nil, err
}
return &dto.LoginResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
User: dto.NewUserDTO(user),
ExpiresIn: 3600,
}, nil
}
// RefreshAccessToken refreshes an access token
func (s *AuthService) RefreshAccessToken(ctx context.Context, refreshToken string) (string, error) {
claims, err := s.jwtManager.VerifyRefreshToken(refreshToken)
if err != nil {
return "", err
}
user, err := s.userRepo.GetUserByID(ctx, mustParseObjectID(claims.UserID))
if err != nil {
return "", err
}
return s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
}
// RequestPasswordReset initiates password reset flow
func (s *AuthService) RequestPasswordReset(ctx context.Context, email string) error {
user, err := s.userRepo.GetUserByEmail(ctx, email)
if err != nil {
// Don't reveal if email exists (security best practice)
return nil
}
token, err := auth.GenerateRandomToken(32)
if err != nil {
return err
}
recovery := &entities.AccountRecovery{
UserID: user.ID,
Token: token,
Type: "password_reset",
ExpiresAt: time.Now().Add(1 * time.Hour),
}
// Save recovery token
// This would need AccountRecoveryRepository implementation
_ = recovery
// In production: send email with reset link containing token
return nil
}
// mustParseObjectID parses a string to ObjectID, panics on error
func mustParseObjectID(id string) bson.ObjectID {
objID, _ := bson.ObjectIDFromHex(id)
return objID
}
// ListProviders returns all active OAuth/OIDC providers.
func (s *AuthService) ListProviders(ctx context.Context) ([]*dto.AuthProviderDTO, error) {
flags, err := s.GetFeatureFlags(ctx)
if err != nil {
return nil, err
}
if !flags.ProviderLoginEnabled {
return []*dto.AuthProviderDTO{}, nil
}
if s.providerRepo == nil {
return []*dto.AuthProviderDTO{}, nil
}
providers, err := s.providerRepo.GetAllProviders(ctx)
if err != nil {
return nil, err
}
result := make([]*dto.AuthProviderDTO, 0, len(providers))
for _, provider := range providers {
result = append(result, dto.NewAuthProviderDTO(provider))
}
return result, nil
}
// GetFeatureFlags returns current app-wide feature flags.
func (s *AuthService) GetFeatureFlags(ctx context.Context) (*dto.FeatureFlagsDTO, error) {
if s.featureFlagRepo == nil {
return dto.NewFeatureFlagsDTO(nil), nil
}
flags, err := s.featureFlagRepo.GetFeatureFlags(ctx)
if err != nil {
return nil, err
}
return dto.NewFeatureFlagsDTO(flags), nil
}
// CreateProvider stores a new OAuth/OIDC provider.
func (s *AuthService) CreateProvider(ctx context.Context, req *dto.CreateAuthProviderRequest) (*dto.AuthProviderDTO, error) {
if s.providerRepo == nil || s.encryptor == nil {
return nil, errors.New("provider configuration unavailable")
}
providerType := strings.ToLower(strings.TrimSpace(req.Type))
if providerType != "oidc" && providerType != "oauth2" {
return nil, errors.New("provider type must be oidc or oauth2")
}
name := strings.TrimSpace(req.Name)
clientID := strings.TrimSpace(req.ClientID)
clientSecret := strings.TrimSpace(req.ClientSecret)
authorizationURL := strings.TrimSpace(req.AuthorizationURL)
tokenURL := strings.TrimSpace(req.TokenURL)
if name == "" || clientID == "" || clientSecret == "" || authorizationURL == "" || tokenURL == "" {
return nil, errors.New("missing required provider fields")
}
encryptedSecret, err := s.encryptor.Encrypt(clientSecret)
if err != nil {
return nil, err
}
provider := &entities.AuthProvider{
Name: name,
Type: providerType,
ClientID: clientID,
ClientSecret: encryptedSecret,
AuthorizationURL: authorizationURL,
TokenURL: tokenURL,
UserInfoURL: strings.TrimSpace(req.UserInfoURL),
Scopes: normalizeScopes(req.Scopes, providerType),
IDTokenClaim: strings.TrimSpace(req.IDTokenClaim),
IsActive: req.IsActive,
}
if err := s.providerRepo.CreateProvider(ctx, provider); err != nil {
return nil, err
}
return dto.NewAuthProviderDTO(provider), nil
}
// BuildProviderAuthorizationURL constructs a provider authorization URL.
func (s *AuthService) BuildProviderAuthorizationURL(ctx context.Context, providerID bson.ObjectID, redirectURI, state string) (string, error) {
flags, err := s.GetFeatureFlags(ctx)
if err != nil {
return "", err
}
if !flags.ProviderLoginEnabled {
return "", errors.New("provider login is currently disabled")
}
provider, secret, err := s.getProviderConfig(ctx, providerID)
if err != nil {
return "", err
}
config := oauth2.Config{
ClientID: provider.ClientID,
ClientSecret: secret,
RedirectURL: redirectURI,
Scopes: normalizeScopes(provider.Scopes, provider.Type),
Endpoint: oauth2.Endpoint{
AuthURL: provider.AuthorizationURL,
TokenURL: provider.TokenURL,
},
}
return config.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
}
// CompleteProviderLogin exchanges an auth code and creates a user session.
func (s *AuthService) CompleteProviderLogin(ctx context.Context, providerID bson.ObjectID, code, redirectURI string) (*dto.LoginResponse, error) {
if s.providerRepo == nil || s.linkRepo == nil {
return nil, errors.New("provider login unavailable")
}
flags, err := s.GetFeatureFlags(ctx)
if err != nil {
return nil, err
}
if !flags.ProviderLoginEnabled {
return nil, errors.New("provider login is currently disabled")
}
provider, secret, err := s.getProviderConfig(ctx, providerID)
if err != nil {
return nil, err
}
config := oauth2.Config{
ClientID: provider.ClientID,
ClientSecret: secret,
RedirectURL: redirectURI,
Scopes: normalizeScopes(provider.Scopes, provider.Type),
Endpoint: oauth2.Endpoint{
AuthURL: provider.AuthorizationURL,
TokenURL: provider.TokenURL,
},
}
token, err := config.Exchange(ctx, code)
if err != nil {
return nil, err
}
profile, err := s.fetchProviderProfile(ctx, provider, token.AccessToken, token.Extra(provider.IDTokenClaim))
if err != nil {
return nil, err
}
user, err := s.findOrCreateOAuthUser(ctx, provider, profile)
if err != nil {
return nil, err
}
accessToken, err := s.jwtManager.GenerateAccessToken(user.ID.Hex(), user.Email, user.Username)
if err != nil {
return nil, err
}
refreshToken, err := s.jwtManager.GenerateRefreshToken(user.ID.Hex())
if err != nil {
return nil, err
}
return &dto.LoginResponse{AccessToken: accessToken, RefreshToken: refreshToken, User: dto.NewUserDTO(user), ExpiresIn: 3600}, nil
}
type providerProfile struct {
ProviderUserID string
Email string
Username string
FirstName string
LastName string
}
func (s *AuthService) getProviderConfig(ctx context.Context, providerID bson.ObjectID) (*entities.AuthProvider, string, error) {
provider, err := s.providerRepo.GetProviderByID(ctx, providerID)
if err != nil {
return nil, "", err
}
if !provider.IsActive {
return nil, "", errors.New("provider is inactive")
}
secret, err := s.encryptor.Decrypt(provider.ClientSecret)
if err != nil {
return nil, "", err
}
return provider, secret, nil
}
func (s *AuthService) fetchProviderProfile(ctx context.Context, provider *entities.AuthProvider, accessToken string, rawIDToken any) (*providerProfile, error) {
payload := map[string]any{}
if provider.UserInfoURL != "" {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, provider.UserInfoURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("provider userinfo request failed: %s", string(body))
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return nil, err
}
} else if idToken, ok := rawIDToken.(string); ok && idToken != "" {
payload = decodeJWTWithoutVerify(idToken)
} else {
return nil, errors.New("provider must define userinfo_url or return id_token")
}
profile := &providerProfile{
ProviderUserID: firstNonEmpty(asString(payload["sub"]), asString(payload["id"]), asString(payload["user_id"])),
Email: strings.ToLower(strings.TrimSpace(firstNonEmpty(asString(payload["email"]), asString(payload["upn"])))),
Username: firstNonEmpty(asString(payload["preferred_username"]), asString(payload["login"]), asString(payload["name"])),
FirstName: firstNonEmpty(asString(payload["given_name"]), asString(payload["first_name"])),
LastName: firstNonEmpty(asString(payload["family_name"]), asString(payload["last_name"])),
}
if profile.ProviderUserID == "" {
return nil, errors.New("provider user info missing subject identifier")
}
if profile.Email == "" {
profile.Email = fmt.Sprintf("%s@%s.oauth.local", sanitizeUsername(profile.ProviderUserID), sanitizeUsername(provider.Name))
}
if profile.Username == "" {
profile.Username = strings.Split(profile.Email, "@")[0]
}
profile.Username = sanitizeUsername(profile.Username)
return profile, nil
}
func (s *AuthService) findOrCreateOAuthUser(ctx context.Context, provider *entities.AuthProvider, profile *providerProfile) (*entities.User, error) {
if link, err := s.linkRepo.GetLinkByProviderUserID(ctx, provider.ID, profile.ProviderUserID); err == nil {
return s.userRepo.GetUserByID(ctx, link.UserID)
}
user, err := s.userRepo.GetUserByEmail(ctx, profile.Email)
if err != nil {
username, err := s.generateUniqueUsername(ctx, profile.Username)
if err != nil {
return nil, err
}
user = &entities.User{Email: profile.Email, Username: username, PasswordHash: "", FirstName: profile.FirstName, LastName: profile.LastName, IsActive: true, EmailVerified: true}
if err := s.userRepo.CreateUser(ctx, user); err != nil {
return nil, err
}
}
if _, err := s.linkRepo.GetLink(ctx, user.ID, provider.ID); err != nil {
if err := s.linkRepo.CreateLink(ctx, &entities.UserProviderLink{UserID: user.ID, ProviderID: provider.ID, ProviderUserID: profile.ProviderUserID, Email: profile.Email}); err != nil {
return nil, err
}
}
return user, nil
}
func (s *AuthService) generateUniqueUsername(ctx context.Context, base string) (string, error) {
base = sanitizeUsername(base)
candidates := []string{base}
for i := 0; i < 5; i++ {
token, err := auth.GenerateRandomToken(2)
if err != nil {
return "", err
}
candidates = append(candidates, fmt.Sprintf("%s-%s", base, token[:4]))
}
for _, candidate := range candidates {
if _, err := s.userRepo.GetUserByUsername(ctx, candidate); err != nil {
return candidate, nil
}
}
return fmt.Sprintf("%s-%d", base, time.Now().Unix()), nil
}
func normalizeScopes(scopes []string, providerType string) []string {
if len(scopes) == 0 {
if providerType == "oidc" {
return []string{"openid", "profile", "email"}
}
return []string{"profile", "email"}
}
result := make([]string, 0, len(scopes))
for _, scope := range scopes {
scope = strings.TrimSpace(scope)
if scope != "" {
result = append(result, scope)
}
}
return result
}
func decodeJWTWithoutVerify(token string) map[string]any {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return map[string]any{}
}
decoded, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return map[string]any{}
}
claims := map[string]any{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return map[string]any{}
}
return claims
}
func asString(value any) string {
if str, ok := value.(string); ok {
return strings.TrimSpace(str)
}
return ""
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if value != "" {
return value
}
}
return ""
}
func sanitizeUsername(value string) string {
cleaned := regexp.MustCompile(`[^a-zA-Z0-9_-]+`).ReplaceAllString(strings.ToLower(strings.TrimSpace(value)), "-")
cleaned = strings.Trim(cleaned, "-")
if cleaned == "" {
return "user"
}
return cleaned
}