135 lines
3.8 KiB
Go
135 lines
3.8 KiB
Go
package grpcserver
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
|
|
"github.com/mrhid6/keymanager/server/internal/grpc/pb"
|
|
"github.com/mrhid6/keymanager/server/internal/services"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/encoding"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
func init() {
|
|
encoding.RegisterCodec(JSONCodec{})
|
|
}
|
|
|
|
type keyManagerServer struct {
|
|
pb.UnimplementedKeyManagerServer
|
|
}
|
|
|
|
func (s *keyManagerServer) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
|
|
agentToken, err := services.RegisterServer(req.ServerId, req.PreRegToken, req.Hostname, req.IpAddress, req.OsInfo)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.InvalidArgument, "registration failed: %v", err)
|
|
}
|
|
return &pb.RegisterResponse{AgentToken: agentToken}, nil
|
|
}
|
|
|
|
func (s *keyManagerServer) SyncKeys(ctx context.Context, req *pb.SyncRequest) (*pb.SyncResponse, error) {
|
|
srv, err := services.ValidateAgentToken(req.ServerId, req.AgentToken)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Unauthenticated, "invalid agent token")
|
|
}
|
|
|
|
if err := services.UpdateServerLastSeen(srv.ServerID); err != nil {
|
|
log.Printf("failed to update last seen for %s: %v", srv.ServerID, err)
|
|
}
|
|
|
|
keys, err := services.BuildAuthorizedKeys(req.ServerId)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "failed to build authorized keys: %v", err)
|
|
}
|
|
|
|
return &pb.SyncResponse{PublicKeys: keys}, nil
|
|
}
|
|
|
|
func (s *keyManagerServer) UploadGeneratedKey(ctx context.Context, req *pb.UploadKeyRequest) (*pb.UploadKeyResponse, error) {
|
|
srv, err := services.ValidateAgentToken(req.ServerId, req.AgentToken)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Unauthenticated, "invalid agent token")
|
|
}
|
|
|
|
key, err := services.CreateKey(req.Label, req.PublicKey, "generated", srv.ServerID)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "failed to store key: %v", err)
|
|
}
|
|
|
|
// Auto-assign to the generating server
|
|
if _, err := services.AssignKey(key.KeyID, srv.ServerID); err != nil {
|
|
log.Printf("failed to auto-assign generated key: %v", err)
|
|
}
|
|
|
|
return &pb.UploadKeyResponse{KeyId: key.KeyID}, nil
|
|
}
|
|
|
|
func (s *keyManagerServer) CommandStream(stream pb.KeyManager_CommandStreamServer) error {
|
|
// First message authenticates the agent and signals readiness.
|
|
msg, err := stream.Recv()
|
|
if err != nil {
|
|
return status.Errorf(codes.InvalidArgument, "expected initial auth message: %v", err)
|
|
}
|
|
|
|
srv, err := services.ValidateAgentToken(msg.ServerId, msg.AgentToken)
|
|
if err != nil {
|
|
return status.Errorf(codes.Unauthenticated, "invalid agent token")
|
|
}
|
|
|
|
if err := services.UpdateServerLastSeen(srv.ServerID); err != nil {
|
|
log.Printf("update last seen %s: %v", srv.ServerID, err)
|
|
}
|
|
|
|
ch := services.Dispatcher.Connect(srv.ServerID)
|
|
defer services.Dispatcher.Disconnect(srv.ServerID)
|
|
|
|
log.Printf("agent %s connected command stream", srv.ServerID)
|
|
defer log.Printf("agent %s disconnected command stream", srv.ServerID)
|
|
|
|
// Drain inbound results in the background so client Send calls never block.
|
|
// UploadGeneratedKey handles the real storage; these are just confirmation logs.
|
|
go func() {
|
|
for {
|
|
m, err := stream.Recv()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if m.Result != nil {
|
|
r := m.Result
|
|
log.Printf("agent %s cmd %s: success=%v %s", srv.ServerID, r.CommandId, r.Success, r.Message)
|
|
}
|
|
}
|
|
}()
|
|
|
|
ctx := stream.Context()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case cmd, ok := <-ch:
|
|
if !ok {
|
|
return nil
|
|
}
|
|
if err := stream.Send(cmd); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func StartGRPC(port int) error {
|
|
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen: %w", err)
|
|
}
|
|
|
|
s := grpc.NewServer()
|
|
pb.RegisterKeyManagerServer(s, &keyManagerServer{})
|
|
|
|
log.Printf("gRPC server listening on :%d", port)
|
|
return s.Serve(lis)
|
|
}
|