updates
This commit is contained in:
@@ -102,3 +102,9 @@ func (c *Client) UploadGeneratedKey(serverID, agentToken, publicKey, label strin
|
||||
}
|
||||
return resp.KeyId, nil
|
||||
}
|
||||
|
||||
// CommandStream opens a long-lived bidirectional stream for server-pushed commands.
|
||||
// The caller controls the stream lifetime via ctx.
|
||||
func (c *Client) CommandStream(ctx context.Context) (pb.KeyManager_CommandStreamClient, error) {
|
||||
return c.client.CommandStream(ctx)
|
||||
}
|
||||
|
||||
@@ -42,10 +42,85 @@ type UploadKeyResponse struct {
|
||||
KeyId string `json:"key_id"`
|
||||
}
|
||||
|
||||
// CommandStream message types
|
||||
|
||||
type ServerCommand struct {
|
||||
CommandId string `json:"command_id"`
|
||||
GenerateKey *GenerateKeyCmd `json:"generate_key,omitempty"`
|
||||
}
|
||||
|
||||
type GenerateKeyCmd struct {
|
||||
Label string `json:"label"`
|
||||
}
|
||||
|
||||
type AgentMessage struct {
|
||||
ServerId string `json:"server_id"`
|
||||
AgentToken string `json:"agent_token"`
|
||||
Ready *AgentReady `json:"ready,omitempty"`
|
||||
Result *CommandResult `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
type AgentReady struct{}
|
||||
|
||||
type CommandResult struct {
|
||||
CommandId string `json:"command_id"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// CommandStream client-side interface
|
||||
|
||||
type KeyManager_CommandStreamClient interface {
|
||||
Send(*AgentMessage) error
|
||||
Recv() (*ServerCommand, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type keyManagerCommandStreamClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (c *keyManagerCommandStreamClient) Send(m *AgentMessage) error {
|
||||
return c.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (c *keyManagerCommandStreamClient) Recv() (*ServerCommand, error) {
|
||||
m := new(ServerCommand)
|
||||
if err := c.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// CommandStream server-side interface (included for completeness)
|
||||
|
||||
type KeyManager_CommandStreamServer interface {
|
||||
Send(*ServerCommand) error
|
||||
Recv() (*AgentMessage, error)
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type keyManagerCommandStreamServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (s *keyManagerCommandStreamServer) Send(m *ServerCommand) error {
|
||||
return s.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (s *keyManagerCommandStreamServer) Recv() (*AgentMessage, error) {
|
||||
m := new(AgentMessage)
|
||||
if err := s.ServerStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
type KeyManagerClient interface {
|
||||
Register(ctx context.Context, in *RegisterRequest, opts ...grpc.CallOption) (*RegisterResponse, error)
|
||||
SyncKeys(ctx context.Context, in *SyncRequest, opts ...grpc.CallOption) (*SyncResponse, error)
|
||||
UploadGeneratedKey(ctx context.Context, in *UploadKeyRequest, opts ...grpc.CallOption) (*UploadKeyResponse, error)
|
||||
CommandStream(ctx context.Context, opts ...grpc.CallOption) (KeyManager_CommandStreamClient, error)
|
||||
}
|
||||
|
||||
type UnimplementedKeyManagerServer struct{}
|
||||
@@ -91,3 +166,12 @@ func (c *keyManagerClient) UploadGeneratedKey(ctx context.Context, in *UploadKey
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *keyManagerClient) CommandStream(ctx context.Context, opts ...grpc.CallOption) (KeyManager_CommandStreamClient, error) {
|
||||
desc := &grpc.StreamDesc{StreamName: "CommandStream", ServerStreams: true, ClientStreams: true}
|
||||
stream, err := c.cc.NewStream(ctx, desc, "/keymanager.v1.KeyManager/CommandStream", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &keyManagerCommandStreamClient{stream}, nil
|
||||
}
|
||||
|
||||
+105
-6
@@ -1,6 +1,7 @@
|
||||
package agentsync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
@@ -11,10 +12,11 @@ import (
|
||||
|
||||
"github.com/mrhid6/keymanager/agent/internal/config"
|
||||
grpcclient "github.com/mrhid6/keymanager/agent/internal/grpc"
|
||||
"github.com/mrhid6/keymanager/agent/internal/grpc/pb"
|
||||
"github.com/mrhid6/keymanager/agent/internal/keys"
|
||||
)
|
||||
|
||||
func Run(cfg *config.Config) error {
|
||||
func Run(ctx context.Context, cfg *config.Config) error {
|
||||
client, err := grpcclient.New(cfg.ServerURL, cfg.TLS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial grpc: %w", err)
|
||||
@@ -40,7 +42,6 @@ func Run(cfg *config.Config) error {
|
||||
}
|
||||
log.Println("registration successful")
|
||||
|
||||
// Reconnect with potentially updated state
|
||||
client.Close()
|
||||
client, err = grpcclient.New(cfg.ServerURL, cfg.TLS)
|
||||
if err != nil {
|
||||
@@ -52,6 +53,9 @@ func Run(cfg *config.Config) error {
|
||||
return fmt.Errorf("no agent token available — registration required")
|
||||
}
|
||||
|
||||
// Start the command stream alongside the poll loop.
|
||||
go runCommandStream(ctx, cfg)
|
||||
|
||||
ticker := time.NewTicker(cfg.PollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -60,12 +64,16 @@ func Run(cfg *config.Config) error {
|
||||
log.Printf("poll error: %v", err)
|
||||
}
|
||||
|
||||
for range ticker.C {
|
||||
if err := poll(client, cfg); err != nil {
|
||||
log.Printf("poll error: %v", err)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
if err := poll(client, cfg); err != nil {
|
||||
log.Printf("poll error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func poll(client *grpcclient.Client, cfg *config.Config) error {
|
||||
@@ -91,6 +99,97 @@ func poll(client *grpcclient.Client, cfg *config.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// runCommandStream maintains a persistent bidirectional stream with the server
|
||||
// for instant command delivery. Reconnects with exponential backoff on failure.
|
||||
func runCommandStream(ctx context.Context, cfg *config.Config) {
|
||||
backoff := time.Second
|
||||
const maxBackoff = 2 * time.Minute
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err := connectAndHandleStream(ctx, cfg); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Printf("command stream error: %v, reconnecting in %s", err, backoff)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
if backoff < maxBackoff {
|
||||
backoff *= 2
|
||||
}
|
||||
} else {
|
||||
backoff = time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func connectAndHandleStream(ctx context.Context, cfg *config.Config) error {
|
||||
client, err := grpcclient.New(cfg.ServerURL, cfg.TLS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
stream, err := client.CommandStream(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open stream: %w", err)
|
||||
}
|
||||
|
||||
if err := stream.Send(&pb.AgentMessage{
|
||||
ServerId: cfg.ServerID,
|
||||
AgentToken: cfg.AgentToken,
|
||||
Ready: &pb.AgentReady{},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send auth: %w", err)
|
||||
}
|
||||
|
||||
log.Println("command stream connected")
|
||||
|
||||
for {
|
||||
cmd, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("recv: %w", err)
|
||||
}
|
||||
|
||||
if cmd.GenerateKey != nil {
|
||||
go handleGenerateKey(cfg, cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleGenerateKey(cfg *config.Config, cmd *pb.ServerCommand) {
|
||||
label := cmd.GenerateKey.Label
|
||||
keyPath := fmt.Sprintf("/root/.ssh/keymanager_%s", strings.ReplaceAll(label, " ", "_"))
|
||||
|
||||
pubKey, err := keys.GenerateKeyPair(keyPath, label)
|
||||
if err != nil {
|
||||
log.Printf("key generation failed (cmd=%s): %v", cmd.CommandId, err)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := grpcclient.New(cfg.ServerURL, cfg.TLS)
|
||||
if err != nil {
|
||||
log.Printf("dial for key upload failed (cmd=%s): %v", cmd.CommandId, err)
|
||||
return
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
keyID, err := client.UploadGeneratedKey(cfg.ServerID, cfg.AgentToken, pubKey, label)
|
||||
if err != nil {
|
||||
log.Printf("key upload failed (cmd=%s): %v", cmd.CommandId, err)
|
||||
return
|
||||
}
|
||||
log.Printf("generated and uploaded key %q (key_id=%s, cmd=%s)", label, keyID, cmd.CommandId)
|
||||
}
|
||||
|
||||
func localIP() string {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user