@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mrhid6/keymanager/server/internal/auth"
|
||||
"github.com/mrhid6/keymanager/server/internal/models"
|
||||
"github.com/mrhid6/keymanager/server/internal/services"
|
||||
)
|
||||
@@ -13,21 +14,30 @@ import (
|
||||
func RegisterRoutes(r *gin.Engine) {
|
||||
r.GET("/install", handleInstallScript)
|
||||
|
||||
api := r.Group("/api")
|
||||
{
|
||||
api.GET("/servers", listServers)
|
||||
api.POST("/servers", createServer)
|
||||
api.GET("/servers/new", newServer)
|
||||
api.POST("/servers/new", newServer)
|
||||
api.GET("/servers/:id", getServer)
|
||||
api.DELETE("/servers/:id", deleteServer)
|
||||
api.POST("/servers/:id/generate-key", generateKey)
|
||||
// Auth endpoints (no session required)
|
||||
r.GET("/auth/login", auth.HandleLogin)
|
||||
r.GET("/auth/callback", auth.HandleCallback)
|
||||
r.GET("/auth/logout", auth.HandleLogout)
|
||||
r.GET("/auth/me", auth.HandleMe)
|
||||
|
||||
api.GET("/keys", listKeys)
|
||||
api.POST("/keys", createKey)
|
||||
api.GET("/keys/:id", getKey)
|
||||
api.POST("/keys/:id/assign", assignKey)
|
||||
api.DELETE("/keys/:id/assign/:serverId", revokeAssignment)
|
||||
// API endpoints protected by session middleware
|
||||
apiGroup := r.Group("/api")
|
||||
apiGroup.Use(auth.Middleware())
|
||||
{
|
||||
apiGroup.GET("/servers", listServers)
|
||||
apiGroup.POST("/servers", createServer)
|
||||
apiGroup.GET("/servers/new", newServer)
|
||||
apiGroup.POST("/servers/new", newServer)
|
||||
apiGroup.GET("/servers/:id", getServer)
|
||||
apiGroup.DELETE("/servers/:id", deleteServer)
|
||||
apiGroup.POST("/servers/:id/generate-key", generateKey)
|
||||
|
||||
apiGroup.GET("/keys", listKeys)
|
||||
apiGroup.POST("/keys", createKey)
|
||||
apiGroup.GET("/keys/:id", getKey)
|
||||
apiGroup.DELETE("/keys/:id", deleteKey)
|
||||
apiGroup.POST("/keys/:id/assign", assignKey)
|
||||
apiGroup.DELETE("/keys/:id/assign/:serverId", revokeAssignment)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,12 +173,26 @@ func getKey(c *gin.Context) {
|
||||
}
|
||||
|
||||
assignments, _ := services.GetAssignmentsWithServers(id)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"key": key,
|
||||
"assignments": assignments,
|
||||
|
||||
type keyResponse struct {
|
||||
*models.Key
|
||||
Assignments any `json:"assignments"`
|
||||
}
|
||||
c.JSON(http.StatusOK, keyResponse{
|
||||
Key: key,
|
||||
Assignments: assignments,
|
||||
})
|
||||
}
|
||||
|
||||
func deleteKey(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := services.DeleteKey(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func assignKey(c *gin.Context) {
|
||||
keyID := c.Param("id")
|
||||
var body struct {
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const ctxSessionKey = "km_session"
|
||||
|
||||
func GetSessionFromContext(c *gin.Context) *Session {
|
||||
v, _ := c.Get(ctxSessionKey)
|
||||
sess, _ := v.(*Session)
|
||||
return sess
|
||||
}
|
||||
|
||||
func Middleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !authEnabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := c.Request.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
sess, err := GetSession(c.Request.Context(), cookie.Value)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "session expired"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(ctxSessionKey, sess)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var (
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Cfg *oauth2.Config
|
||||
authEnabled bool
|
||||
)
|
||||
|
||||
func InitOIDC(ctx context.Context) error {
|
||||
issuer := os.Getenv("OIDC_ISSUER")
|
||||
if issuer == "" {
|
||||
log.Println("OIDC_ISSUER not set; authentication disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
p, err := oidc.NewProvider(ctx, issuer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oidcProvider = p
|
||||
oauth2Cfg = &oauth2.Config{
|
||||
ClientID: os.Getenv("OIDC_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("OIDC_CLIENT_SECRET"),
|
||||
RedirectURL: os.Getenv("OIDC_REDIRECT_URL"),
|
||||
Endpoint: p.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
authEnabled = true
|
||||
log.Println("OIDC authentication enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
func Enabled() bool { return authEnabled }
|
||||
|
||||
func HandleLogin(c *gin.Context) {
|
||||
state, err := randomHex(16)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "state generation failed"})
|
||||
return
|
||||
}
|
||||
if err := SaveState(c.Request.Context(), state); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "state save failed"})
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusFound, oauth2Cfg.AuthCodeURL(state))
|
||||
}
|
||||
|
||||
func HandleCallback(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
if !ConsumeState(ctx, c.Query("state")) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid state"})
|
||||
return
|
||||
}
|
||||
|
||||
token, err := oauth2Cfg.Exchange(ctx, c.Query("code"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "token exchange failed"})
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "missing id_token"})
|
||||
return
|
||||
}
|
||||
|
||||
verifier := oidcProvider.Verifier(&oidc.Config{ClientID: oauth2Cfg.ClientID})
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "token verification failed"})
|
||||
return
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "claims extraction failed"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionID, err := SaveSession(ctx, &Session{
|
||||
UserID: claims.Sub,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "session save failed"})
|
||||
return
|
||||
}
|
||||
|
||||
secure := c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: sessionID,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(sessionTTL.Seconds()),
|
||||
})
|
||||
|
||||
frontendURL := os.Getenv("PUBLIC_HOST")
|
||||
if frontendURL == "" {
|
||||
frontendURL = "/"
|
||||
}
|
||||
c.Redirect(http.StatusFound, frontendURL)
|
||||
}
|
||||
|
||||
func HandleLogout(c *gin.Context) {
|
||||
if cookie, err := c.Request.Cookie(sessionCookieName); err == nil {
|
||||
_ = DeleteSession(c.Request.Context(), cookie.Value)
|
||||
}
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
MaxAge: -1,
|
||||
})
|
||||
c.Redirect(http.StatusFound, "/")
|
||||
}
|
||||
|
||||
func HandleMe(c *gin.Context) {
|
||||
if !authEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{"auth_enabled": false})
|
||||
return
|
||||
}
|
||||
cookie, err := c.Request.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"})
|
||||
return
|
||||
}
|
||||
sess, err := GetSession(c.Request.Context(), cookie.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "session expired"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, sess)
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const sessionTTL = 24 * time.Hour
|
||||
const sessionCookieName = "km_session"
|
||||
const sessionPrefix = "km:session:"
|
||||
const statePrefix = "km:state:"
|
||||
|
||||
type Session struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
var rdb *redis.Client
|
||||
|
||||
func InitRedis(addr string) error {
|
||||
rdb = redis.NewClient(&redis.Options{Addr: addr})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
return rdb.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
func randomHex(n int) (string, error) {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func SaveSession(ctx context.Context, sess *Session) (string, error) {
|
||||
id, err := randomHex(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data, err := json.Marshal(sess)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := rdb.Set(ctx, sessionPrefix+id, data, sessionTTL).Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func GetSession(ctx context.Context, id string) (*Session, error) {
|
||||
data, err := rdb.Get(ctx, sessionPrefix+id).Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var sess Session
|
||||
if err := json.Unmarshal(data, &sess); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sess, nil
|
||||
}
|
||||
|
||||
func DeleteSession(ctx context.Context, id string) error {
|
||||
return rdb.Del(ctx, sessionPrefix+id).Err()
|
||||
}
|
||||
|
||||
func SaveState(ctx context.Context, state string) error {
|
||||
return rdb.Set(ctx, statePrefix+state, "1", 10*time.Minute).Err()
|
||||
}
|
||||
|
||||
func ConsumeState(ctx context.Context, state string) bool {
|
||||
n, err := rdb.Del(ctx, statePrefix+state).Result()
|
||||
return err == nil && n > 0
|
||||
}
|
||||
@@ -63,7 +63,12 @@ func GetKey(keyID string) (*models.Key, error) {
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func ListKeys() ([]models.Key, error) {
|
||||
type KeyWithCount struct {
|
||||
models.Key `bson:",inline"`
|
||||
AssignedCount int `bson:"-" json:"assigned_count"`
|
||||
}
|
||||
|
||||
func ListKeys() ([]KeyWithCount, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -77,7 +82,16 @@ func ListKeys() ([]models.Key, error) {
|
||||
if err := cursor.All(ctx, &keys); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
|
||||
result := make([]KeyWithCount, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
count, _ := db.Col("assignments").CountDocuments(ctx, bson.M{
|
||||
"key_id": k.KeyID,
|
||||
"revoked_at": nil,
|
||||
})
|
||||
result = append(result, KeyWithCount{Key: k, AssignedCount: int(count)})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func DeleteKey(keyID string) error {
|
||||
|
||||
Reference in New Issue
Block a user