diff --git a/agent/cmd/main.go b/agent/cmd/main.go index 0bda460..865f601 100644 --- a/agent/cmd/main.go +++ b/agent/cmd/main.go @@ -1,8 +1,11 @@ package main import ( + "context" "flag" "log" + "os/signal" + "syscall" "github.com/mrhid6/keymanager/agent/internal/config" agentsync "github.com/mrhid6/keymanager/agent/internal/sync" @@ -26,8 +29,11 @@ func main() { return } + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + log.Printf("keymanager-agent %s starting (server=%s, poll=%s)", Version, cfg.ServerURL, cfg.PollInterval) - if err := agentsync.Run(cfg); err != nil { + if err := agentsync.Run(ctx, cfg); err != nil { log.Fatalf("agent error: %v", err) } } diff --git a/agent/internal/grpc/client.go b/agent/internal/grpc/client.go index 306e49b..55e047e 100644 --- a/agent/internal/grpc/client.go +++ b/agent/internal/grpc/client.go @@ -102,3 +102,9 @@ func (c *Client) UploadGeneratedKey(serverID, agentToken, publicKey, label strin } return resp.KeyId, nil } + +// CommandStream opens a long-lived bidirectional stream for server-pushed commands. +// The caller controls the stream lifetime via ctx. +func (c *Client) CommandStream(ctx context.Context) (pb.KeyManager_CommandStreamClient, error) { + return c.client.CommandStream(ctx) +} diff --git a/agent/internal/grpc/pb/keymanager.pb.go b/agent/internal/grpc/pb/keymanager.pb.go index 3a209af..79aecbd 100644 --- a/agent/internal/grpc/pb/keymanager.pb.go +++ b/agent/internal/grpc/pb/keymanager.pb.go @@ -42,10 +42,85 @@ type UploadKeyResponse struct { KeyId string `json:"key_id"` } +// CommandStream message types + +type ServerCommand struct { + CommandId string `json:"command_id"` + GenerateKey *GenerateKeyCmd `json:"generate_key,omitempty"` +} + +type GenerateKeyCmd struct { + Label string `json:"label"` +} + +type AgentMessage struct { + ServerId string `json:"server_id"` + AgentToken string `json:"agent_token"` + Ready *AgentReady `json:"ready,omitempty"` + Result *CommandResult `json:"result,omitempty"` +} + +type AgentReady struct{} + +type CommandResult struct { + CommandId string `json:"command_id"` + Success bool `json:"success"` + Message string `json:"message"` +} + +// CommandStream client-side interface + +type KeyManager_CommandStreamClient interface { + Send(*AgentMessage) error + Recv() (*ServerCommand, error) + grpc.ClientStream +} + +type keyManagerCommandStreamClient struct { + grpc.ClientStream +} + +func (c *keyManagerCommandStreamClient) Send(m *AgentMessage) error { + return c.ClientStream.SendMsg(m) +} + +func (c *keyManagerCommandStreamClient) Recv() (*ServerCommand, error) { + m := new(ServerCommand) + if err := c.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// CommandStream server-side interface (included for completeness) + +type KeyManager_CommandStreamServer interface { + Send(*ServerCommand) error + Recv() (*AgentMessage, error) + grpc.ServerStream +} + +type keyManagerCommandStreamServer struct { + grpc.ServerStream +} + +func (s *keyManagerCommandStreamServer) Send(m *ServerCommand) error { + return s.ServerStream.SendMsg(m) +} + +func (s *keyManagerCommandStreamServer) Recv() (*AgentMessage, error) { + m := new(AgentMessage) + if err := s.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + type KeyManagerClient interface { Register(ctx context.Context, in *RegisterRequest, opts ...grpc.CallOption) (*RegisterResponse, error) SyncKeys(ctx context.Context, in *SyncRequest, opts ...grpc.CallOption) (*SyncResponse, error) UploadGeneratedKey(ctx context.Context, in *UploadKeyRequest, opts ...grpc.CallOption) (*UploadKeyResponse, error) + CommandStream(ctx context.Context, opts ...grpc.CallOption) (KeyManager_CommandStreamClient, error) } type UnimplementedKeyManagerServer struct{} @@ -91,3 +166,12 @@ func (c *keyManagerClient) UploadGeneratedKey(ctx context.Context, in *UploadKey } return out, nil } + +func (c *keyManagerClient) CommandStream(ctx context.Context, opts ...grpc.CallOption) (KeyManager_CommandStreamClient, error) { + desc := &grpc.StreamDesc{StreamName: "CommandStream", ServerStreams: true, ClientStreams: true} + stream, err := c.cc.NewStream(ctx, desc, "/keymanager.v1.KeyManager/CommandStream", opts...) + if err != nil { + return nil, err + } + return &keyManagerCommandStreamClient{stream}, nil +} diff --git a/agent/internal/sync/sync.go b/agent/internal/sync/sync.go index 9402318..2f6df2d 100644 --- a/agent/internal/sync/sync.go +++ b/agent/internal/sync/sync.go @@ -1,6 +1,7 @@ package agentsync import ( + "context" "fmt" "log" "net" @@ -11,10 +12,11 @@ import ( "github.com/mrhid6/keymanager/agent/internal/config" grpcclient "github.com/mrhid6/keymanager/agent/internal/grpc" + "github.com/mrhid6/keymanager/agent/internal/grpc/pb" "github.com/mrhid6/keymanager/agent/internal/keys" ) -func Run(cfg *config.Config) error { +func Run(ctx context.Context, cfg *config.Config) error { client, err := grpcclient.New(cfg.ServerURL, cfg.TLS) if err != nil { return fmt.Errorf("dial grpc: %w", err) @@ -40,7 +42,6 @@ func Run(cfg *config.Config) error { } log.Println("registration successful") - // Reconnect with potentially updated state client.Close() client, err = grpcclient.New(cfg.ServerURL, cfg.TLS) if err != nil { @@ -52,6 +53,9 @@ func Run(cfg *config.Config) error { return fmt.Errorf("no agent token available — registration required") } + // Start the command stream alongside the poll loop. + go runCommandStream(ctx, cfg) + ticker := time.NewTicker(cfg.PollInterval) defer ticker.Stop() @@ -60,12 +64,16 @@ func Run(cfg *config.Config) error { log.Printf("poll error: %v", err) } - for range ticker.C { - if err := poll(client, cfg); err != nil { - log.Printf("poll error: %v", err) + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + if err := poll(client, cfg); err != nil { + log.Printf("poll error: %v", err) + } } } - return nil } func poll(client *grpcclient.Client, cfg *config.Config) error { @@ -91,6 +99,97 @@ func poll(client *grpcclient.Client, cfg *config.Config) error { return nil } +// runCommandStream maintains a persistent bidirectional stream with the server +// for instant command delivery. Reconnects with exponential backoff on failure. +func runCommandStream(ctx context.Context, cfg *config.Config) { + backoff := time.Second + const maxBackoff = 2 * time.Minute + + for { + select { + case <-ctx.Done(): + return + default: + } + + if err := connectAndHandleStream(ctx, cfg); err != nil { + if ctx.Err() != nil { + return + } + log.Printf("command stream error: %v, reconnecting in %s", err, backoff) + select { + case <-ctx.Done(): + return + case <-time.After(backoff): + } + if backoff < maxBackoff { + backoff *= 2 + } + } else { + backoff = time.Second + } + } +} + +func connectAndHandleStream(ctx context.Context, cfg *config.Config) error { + client, err := grpcclient.New(cfg.ServerURL, cfg.TLS) + if err != nil { + return fmt.Errorf("dial: %w", err) + } + defer client.Close() + + stream, err := client.CommandStream(ctx) + if err != nil { + return fmt.Errorf("open stream: %w", err) + } + + if err := stream.Send(&pb.AgentMessage{ + ServerId: cfg.ServerID, + AgentToken: cfg.AgentToken, + Ready: &pb.AgentReady{}, + }); err != nil { + return fmt.Errorf("send auth: %w", err) + } + + log.Println("command stream connected") + + for { + cmd, err := stream.Recv() + if err != nil { + return fmt.Errorf("recv: %w", err) + } + + if cmd.GenerateKey != nil { + go handleGenerateKey(cfg, cmd) + } + } +} + +func handleGenerateKey(cfg *config.Config, cmd *pb.ServerCommand) { + label := cmd.GenerateKey.Label + keyPath := fmt.Sprintf("/root/.ssh/keymanager_%s", strings.ReplaceAll(label, " ", "_")) + + pubKey, err := keys.GenerateKeyPair(keyPath, label) + if err != nil { + log.Printf("key generation failed (cmd=%s): %v", cmd.CommandId, err) + return + } + + client, err := grpcclient.New(cfg.ServerURL, cfg.TLS) + if err != nil { + log.Printf("dial for key upload failed (cmd=%s): %v", cmd.CommandId, err) + return + } + defer client.Close() + + keyID, err := client.UploadGeneratedKey(cfg.ServerID, cfg.AgentToken, pubKey, label) + if err != nil { + log.Printf("key upload failed (cmd=%s): %v", cmd.CommandId, err) + return + } + log.Printf("generated and uploaded key %q (key_id=%s, cmd=%s)", label, keyID, cmd.CommandId) +} + func localIP() string { addrs, err := net.InterfaceAddrs() if err != nil { diff --git a/proto/keymanager/v1/keymanager.proto b/proto/keymanager/v1/keymanager.proto index 3ac5ad3..296c802 100644 --- a/proto/keymanager/v1/keymanager.proto +++ b/proto/keymanager/v1/keymanager.proto @@ -5,9 +5,11 @@ package keymanager.v1; option go_package = "github.com/mrhid6/keymanager/server/internal/grpc/pb"; service KeyManager { - rpc Register(RegisterRequest) returns (RegisterResponse); - rpc SyncKeys(SyncRequest) returns (SyncResponse); - rpc UploadGeneratedKey(UploadKeyRequest) returns (UploadKeyResponse); + rpc Register(RegisterRequest) returns (RegisterResponse); + rpc SyncKeys(SyncRequest) returns (SyncResponse); + rpc UploadGeneratedKey(UploadKeyRequest) returns (UploadKeyResponse); + // Bidirectional stream: agent sends auth once, server pushes commands. + rpc CommandStream(stream AgentMessage) returns (stream ServerCommand); } message RegisterRequest { @@ -41,3 +43,33 @@ message UploadKeyRequest { message UploadKeyResponse { string key_id = 1; } + +// CommandStream messages + +message AgentMessage { + string server_id = 1; + string agent_token = 2; + oneof payload { + AgentReady ready = 3; + CommandResult result = 4; + } +} + +message AgentReady {} + +message CommandResult { + string command_id = 1; + bool success = 2; + string message = 3; +} + +message ServerCommand { + string command_id = 1; + oneof command { + GenerateKeyCmd generate_key = 2; + } +} + +message GenerateKeyCmd { + string label = 1; +} diff --git a/server/internal/api/handlers.go b/server/internal/api/handlers.go index 7c740c8..8dcc526 100644 --- a/server/internal/api/handlers.go +++ b/server/internal/api/handlers.go @@ -122,18 +122,32 @@ func deleteServer(c *gin.Context) { } func generateKey(c *gin.Context) { - // The agent triggers key generation itself; this endpoint signals - // the intent by returning the server so the caller knows to wait - // for the agent to upload via gRPC UploadGeneratedKey. id := c.Param("id") + + var body struct { + Label string `json:"label"` + } + _ = c.ShouldBindJSON(&body) + if body.Label == "" { + body.Label = "generated" + } + s, err := services.GetServer(id) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "server not found"}) return } - c.JSON(http.StatusOK, gin.H{ - "message": "agent will generate and upload key on next poll", - "server_id": s.ServerID, + + cmdID, err := services.DispatchGenerateKey(s.ServerID, body.Label) + if err != nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusAccepted, gin.H{ + "message": "key generation command sent to agent", + "command_id": cmdID, + "server_id": s.ServerID, }) } diff --git a/server/internal/grpc/pb/keymanager.pb.go b/server/internal/grpc/pb/keymanager.pb.go index 663385c..56e3cb3 100644 --- a/server/internal/grpc/pb/keymanager.pb.go +++ b/server/internal/grpc/pb/keymanager.pb.go @@ -45,12 +45,87 @@ type UploadKeyResponse struct { KeyId string `json:"key_id"` } +// CommandStream message types + +type ServerCommand struct { + CommandId string `json:"command_id"` + GenerateKey *GenerateKeyCmd `json:"generate_key,omitempty"` +} + +type GenerateKeyCmd struct { + Label string `json:"label"` +} + +type AgentMessage struct { + ServerId string `json:"server_id"` + AgentToken string `json:"agent_token"` + Ready *AgentReady `json:"ready,omitempty"` + Result *CommandResult `json:"result,omitempty"` +} + +type AgentReady struct{} + +type CommandResult struct { + CommandId string `json:"command_id"` + Success bool `json:"success"` + Message string `json:"message"` +} + +// CommandStream server-side interface + +type KeyManager_CommandStreamServer interface { + Send(*ServerCommand) error + Recv() (*AgentMessage, error) + grpc.ServerStream +} + +type keyManagerCommandStreamServer struct { + grpc.ServerStream +} + +func (s *keyManagerCommandStreamServer) Send(m *ServerCommand) error { + return s.ServerStream.SendMsg(m) +} + +func (s *keyManagerCommandStreamServer) Recv() (*AgentMessage, error) { + m := new(AgentMessage) + if err := s.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// CommandStream client-side interface + +type KeyManager_CommandStreamClient interface { + Send(*AgentMessage) error + Recv() (*ServerCommand, error) + grpc.ClientStream +} + +type keyManagerCommandStreamClient struct { + grpc.ClientStream +} + +func (c *keyManagerCommandStreamClient) Send(m *AgentMessage) error { + return c.ClientStream.SendMsg(m) +} + +func (c *keyManagerCommandStreamClient) Recv() (*ServerCommand, error) { + m := new(ServerCommand) + if err := c.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + // Server interface type KeyManagerServer interface { Register(context.Context, *RegisterRequest) (*RegisterResponse, error) SyncKeys(context.Context, *SyncRequest) (*SyncResponse, error) UploadGeneratedKey(context.Context, *UploadKeyRequest) (*UploadKeyResponse, error) + CommandStream(KeyManager_CommandStreamServer) error } type UnimplementedKeyManagerServer struct{} @@ -67,12 +142,17 @@ func (UnimplementedKeyManagerServer) UploadGeneratedKey(context.Context, *Upload return nil, status.Errorf(codes.Unimplemented, "method UploadGeneratedKey not implemented") } +func (UnimplementedKeyManagerServer) CommandStream(KeyManager_CommandStreamServer) error { + return status.Errorf(codes.Unimplemented, "method CommandStream not implemented") +} + // Client interface type KeyManagerClient interface { Register(ctx context.Context, in *RegisterRequest, opts ...grpc.CallOption) (*RegisterResponse, error) SyncKeys(ctx context.Context, in *SyncRequest, opts ...grpc.CallOption) (*SyncResponse, error) UploadGeneratedKey(ctx context.Context, in *UploadKeyRequest, opts ...grpc.CallOption) (*UploadKeyResponse, error) + CommandStream(ctx context.Context, opts ...grpc.CallOption) (KeyManager_CommandStreamClient, error) } type keyManagerClient struct { @@ -107,6 +187,14 @@ func (c *keyManagerClient) UploadGeneratedKey(ctx context.Context, in *UploadKey return out, nil } +func (c *keyManagerClient) CommandStream(ctx context.Context, opts ...grpc.CallOption) (KeyManager_CommandStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &KeyManager_ServiceDesc.Streams[0], "/keymanager.v1.KeyManager/CommandStream", opts...) + if err != nil { + return nil, err + } + return &keyManagerCommandStreamClient{stream}, nil +} + // Server registration func RegisterKeyManagerServer(s grpc.ServiceRegistrar, srv KeyManagerServer) { @@ -121,7 +209,14 @@ var KeyManager_ServiceDesc = grpc.ServiceDesc{ {MethodName: "SyncKeys", Handler: _KeyManager_SyncKeys_Handler}, {MethodName: "UploadGeneratedKey", Handler: _KeyManager_UploadGeneratedKey_Handler}, }, - Streams: []grpc.StreamDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "CommandStream", + Handler: _KeyManager_CommandStream_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, Metadata: "keymanager/v1/keymanager.proto", } @@ -169,3 +264,7 @@ func _KeyManager_UploadGeneratedKey_Handler(srv interface{}, ctx context.Context } return interceptor(ctx, in, info, handler) } + +func _KeyManager_CommandStream_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(KeyManagerServer).CommandStream(&keyManagerCommandStreamServer{stream}) +} diff --git a/server/internal/grpc/server.go b/server/internal/grpc/server.go index 9e6beb2..dd9da66 100644 --- a/server/internal/grpc/server.go +++ b/server/internal/grpc/server.go @@ -67,6 +67,59 @@ func (s *keyManagerServer) UploadGeneratedKey(ctx context.Context, req *pb.Uploa return &pb.UploadKeyResponse{KeyId: key.KeyID}, nil } +func (s *keyManagerServer) CommandStream(stream pb.KeyManager_CommandStreamServer) error { + // First message authenticates the agent and signals readiness. + msg, err := stream.Recv() + if err != nil { + return status.Errorf(codes.InvalidArgument, "expected initial auth message: %v", err) + } + + srv, err := services.ValidateAgentToken(msg.ServerId, msg.AgentToken) + if err != nil { + return status.Errorf(codes.Unauthenticated, "invalid agent token") + } + + if err := services.UpdateServerLastSeen(srv.ServerID); err != nil { + log.Printf("update last seen %s: %v", srv.ServerID, err) + } + + ch := services.Dispatcher.Connect(srv.ServerID) + defer services.Dispatcher.Disconnect(srv.ServerID) + + log.Printf("agent %s connected command stream", srv.ServerID) + defer log.Printf("agent %s disconnected command stream", srv.ServerID) + + // Drain inbound results in the background so client Send calls never block. + // UploadGeneratedKey handles the real storage; these are just confirmation logs. + go func() { + for { + m, err := stream.Recv() + if err != nil { + return + } + if m.Result != nil { + r := m.Result + log.Printf("agent %s cmd %s: success=%v %s", srv.ServerID, r.CommandId, r.Success, r.Message) + } + } + }() + + ctx := stream.Context() + for { + select { + case <-ctx.Done(): + return nil + case cmd, ok := <-ch: + if !ok { + return nil + } + if err := stream.Send(cmd); err != nil { + return err + } + } + } +} + func StartGRPC(port int) error { lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { diff --git a/server/internal/services/dispatch.go b/server/internal/services/dispatch.go new file mode 100644 index 0000000..bb83db8 --- /dev/null +++ b/server/internal/services/dispatch.go @@ -0,0 +1,76 @@ +package services + +import ( + "fmt" + "sync" + + "github.com/google/uuid" + "github.com/mrhid6/keymanager/server/internal/grpc/pb" +) + +type commandDispatcher struct { + mu sync.RWMutex + channels map[string]chan *pb.ServerCommand +} + +// Dispatcher is the singleton command dispatcher used by both the gRPC server +// and the REST API to push commands to connected agents. +var Dispatcher = &commandDispatcher{ + channels: make(map[string]chan *pb.ServerCommand), +} + +// Connect registers an agent's command channel. Returns the channel to drain. +func (d *commandDispatcher) Connect(serverID string) chan *pb.ServerCommand { + ch := make(chan *pb.ServerCommand, 16) + d.mu.Lock() + d.channels[serverID] = ch + d.mu.Unlock() + return ch +} + +// Disconnect removes the agent's channel on stream close. +func (d *commandDispatcher) Disconnect(serverID string) { + d.mu.Lock() + delete(d.channels, serverID) + d.mu.Unlock() +} + +// IsConnected reports whether an agent is currently holding a CommandStream. +func (d *commandDispatcher) IsConnected(serverID string) bool { + d.mu.RLock() + _, ok := d.channels[serverID] + d.mu.RUnlock() + return ok +} + +func (d *commandDispatcher) dispatch(serverID string, cmd *pb.ServerCommand) error { + d.mu.RLock() + ch, ok := d.channels[serverID] + d.mu.RUnlock() + if !ok { + return fmt.Errorf("agent for server %s is not connected", serverID) + } + select { + case ch <- cmd: + return nil + default: + return fmt.Errorf("command queue full for server %s", serverID) + } +} + +// DispatchGenerateKey sends a generate-key command to the named server's agent. +// Returns the command ID that can be used to correlate the agent's result. +func DispatchGenerateKey(serverID, label string) (string, error) { + if !Dispatcher.IsConnected(serverID) { + return "", fmt.Errorf("agent is not connected to the command stream") + } + cmdID := uuid.New().String() + cmd := &pb.ServerCommand{ + CommandId: cmdID, + GenerateKey: &pb.GenerateKeyCmd{Label: label}, + } + if err := Dispatcher.dispatch(serverID, cmd); err != nil { + return "", err + } + return cmdID, nil +}