229 lines
5.4 KiB
Go
229 lines
5.4 KiB
Go
package agentsync
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"runtime"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/mrhid6/keymanager/agent/internal/config"
|
|
grpcclient "github.com/mrhid6/keymanager/agent/internal/grpc"
|
|
"github.com/mrhid6/keymanager/agent/internal/grpc/pb"
|
|
"github.com/mrhid6/keymanager/agent/internal/keys"
|
|
)
|
|
|
|
func Run(ctx context.Context, cfg *config.Config) error {
|
|
client, err := grpcclient.New(cfg.ServerURL, cfg.TLS)
|
|
if err != nil {
|
|
return fmt.Errorf("dial grpc: %w", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
// Register if we have a pre-reg token
|
|
if cfg.PreRegToken != "" {
|
|
log.Println("registering with server...")
|
|
hostname, _ := os.Hostname()
|
|
ipAddress := localIP()
|
|
osInfo := fmt.Sprintf("%s %s", runtime.GOOS, runtime.GOARCH)
|
|
|
|
agentToken, err := client.Register(cfg.ServerID, cfg.PreRegToken, hostname, ipAddress, osInfo)
|
|
if err != nil {
|
|
return fmt.Errorf("registration failed: %w", err)
|
|
}
|
|
|
|
cfg.AgentToken = agentToken
|
|
cfg.PreRegToken = ""
|
|
if err := config.Save(cfg); err != nil {
|
|
return fmt.Errorf("save config: %w", err)
|
|
}
|
|
log.Println("registration successful")
|
|
|
|
client.Close()
|
|
client, err = grpcclient.New(cfg.ServerURL, cfg.TLS)
|
|
if err != nil {
|
|
return fmt.Errorf("reconnect: %w", err)
|
|
}
|
|
}
|
|
|
|
if cfg.AgentToken == "" {
|
|
return fmt.Errorf("no agent token available — registration required")
|
|
}
|
|
|
|
// Start the command stream alongside the poll loop.
|
|
go runCommandStream(ctx, cfg)
|
|
|
|
ticker := time.NewTicker(cfg.PollInterval)
|
|
defer ticker.Stop()
|
|
|
|
// Run immediately on startup
|
|
if err := poll(client, cfg); err != nil {
|
|
log.Printf("poll error: %v", err)
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case <-ticker.C:
|
|
if err := poll(client, cfg); err != nil {
|
|
log.Printf("poll error: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func poll(client *grpcclient.Client, cfg *config.Config) error {
|
|
desired, err := client.SyncKeys(cfg.ServerID, cfg.AgentToken)
|
|
if err != nil {
|
|
return fmt.Errorf("SyncKeys: %w", err)
|
|
}
|
|
|
|
current, err := keys.ReadAuthorizedKeys()
|
|
if err != nil {
|
|
return fmt.Errorf("read authorized_keys: %w", err)
|
|
}
|
|
|
|
if !keys.StateChanged(current, desired) {
|
|
log.Println("authorized_keys unchanged, skipping write")
|
|
return nil
|
|
}
|
|
|
|
if err := keys.WriteAuthorizedKeys(desired); err != nil {
|
|
return fmt.Errorf("write authorized_keys: %w", err)
|
|
}
|
|
log.Printf("authorized_keys updated (%d keys)", len(desired))
|
|
return nil
|
|
}
|
|
|
|
// runCommandStream maintains a persistent bidirectional stream with the server
|
|
// for instant command delivery. Reconnects with exponential backoff on failure.
|
|
func runCommandStream(ctx context.Context, cfg *config.Config) {
|
|
backoff := time.Second
|
|
const maxBackoff = 2 * time.Minute
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
if err := connectAndHandleStream(ctx, cfg); err != nil {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
log.Printf("command stream error: %v, reconnecting in %s", err, backoff)
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(backoff):
|
|
}
|
|
if backoff < maxBackoff {
|
|
backoff *= 2
|
|
}
|
|
} else {
|
|
backoff = time.Second
|
|
}
|
|
}
|
|
}
|
|
|
|
func connectAndHandleStream(ctx context.Context, cfg *config.Config) error {
|
|
client, err := grpcclient.New(cfg.ServerURL, cfg.TLS)
|
|
if err != nil {
|
|
return fmt.Errorf("dial: %w", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
stream, err := client.CommandStream(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("open stream: %w", err)
|
|
}
|
|
|
|
if err := stream.Send(&pb.AgentMessage{
|
|
ServerId: cfg.ServerID,
|
|
AgentToken: cfg.AgentToken,
|
|
Ready: &pb.AgentReady{},
|
|
}); err != nil {
|
|
return fmt.Errorf("send auth: %w", err)
|
|
}
|
|
|
|
log.Println("command stream connected")
|
|
|
|
for {
|
|
cmd, err := stream.Recv()
|
|
if err != nil {
|
|
return fmt.Errorf("recv: %w", err)
|
|
}
|
|
|
|
if cmd.GenerateKey != nil {
|
|
go handleGenerateKey(cfg, cmd)
|
|
}
|
|
}
|
|
}
|
|
|
|
func handleGenerateKey(cfg *config.Config, cmd *pb.ServerCommand) {
|
|
label := cmd.GenerateKey.Label
|
|
keyPath := fmt.Sprintf("/root/.ssh/keymanager_%s", strings.ReplaceAll(label, " ", "_"))
|
|
|
|
pubKey, err := keys.GenerateKeyPair(keyPath, label)
|
|
if err != nil {
|
|
log.Printf("key generation failed (cmd=%s): %v", cmd.CommandId, err)
|
|
return
|
|
}
|
|
|
|
client, err := grpcclient.New(cfg.ServerURL, cfg.TLS)
|
|
if err != nil {
|
|
log.Printf("dial for key upload failed (cmd=%s): %v", cmd.CommandId, err)
|
|
return
|
|
}
|
|
defer client.Close()
|
|
|
|
keyID, err := client.UploadGeneratedKey(cfg.ServerID, cfg.AgentToken, pubKey, label)
|
|
if err != nil {
|
|
log.Printf("key upload failed (cmd=%s): %v", cmd.CommandId, err)
|
|
return
|
|
}
|
|
log.Printf("generated and uploaded key %q (key_id=%s, cmd=%s)", label, keyID, cmd.CommandId)
|
|
}
|
|
|
|
// localIP returns the textual form of the first IPv4 address assigned to a
// non-loopback interface, in the order net.InterfaceAddrs reports them. It
// returns "" when the addresses cannot be listed or none qualifies.
func localIP() string {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return ""
	}
	for _, a := range addrs {
		ipNet, isIPNet := a.(*net.IPNet)
		if !isIPNet || ipNet.IP.IsLoopback() {
			continue
		}
		if ipNet.IP.To4() == nil {
			// IPv6 address; keep looking for an IPv4 one.
			continue
		}
		return ipNet.IP.String()
	}
	return ""
}
|
|
|
|
// GenerateAndUpload generates an SSH keypair and uploads the public key to the server.
|
|
func GenerateAndUpload(cfg *config.Config, label string) error {
|
|
client, err := grpcclient.New(cfg.ServerURL, cfg.TLS)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer client.Close()
|
|
|
|
keyPath := fmt.Sprintf("/root/.ssh/keymanager_%s", strings.ReplaceAll(label, " ", "_"))
|
|
pubKey, err := keys.GenerateKeyPair(keyPath, label)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
keyID, err := client.UploadGeneratedKey(cfg.ServerID, cfg.AgentToken, pubKey, label)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Printf("uploaded generated key %s (key_id=%s)", label, keyID)
|
|
return nil
|
|
}
|