From 3d7bef3a82075b540c99f87d16594587faac48b7 Mon Sep 17 00:00:00 2001 From: restitux Date: Mon, 9 Jan 2023 00:59:19 -0700 Subject: [PATCH] Implment worker api (fixes #3) --- go.mod | 5 +- go.sum | 6 +- pipeline_api/pipeline_api.go | 57 +++++++++++++++---- runnermanager/runner.go | 78 +++++++++++++++++++++++-- runnermanager/runnermanager.go | 101 +++++++++++++++++++++++++-------- 5 files changed, 202 insertions(+), 45 deletions(-) diff --git a/go.mod b/go.mod index 1ff7d33..5252fdb 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,8 @@ module git.ohea.xyz/cursorius/server go 1.19 require ( - git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230102084147-6d988866458e + git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230109075652-ead0aeff2eb9 + git.ohea.xyz/cursorius/runner-api/go/api/v2 v2.0.0-20230109074922-e20285fe6cf2 git.ohea.xyz/golang/config v0.0.0-20220915224621-b9debd233173 github.com/bufbuild/connect-go v1.4.1 github.com/docker/docker v20.10.22+incompatible @@ -12,6 +13,7 @@ require ( github.com/google/uuid v1.3.0 github.com/op/go-logging v0.0.0-20160315200505-970db520ece7 golang.org/x/net v0.2.0 + google.golang.org/protobuf v1.28.1 nhooyr.io/websocket v1.8.7 ) @@ -50,7 +52,6 @@ require ( golang.org/x/sys v0.2.0 // indirect golang.org/x/text v0.4.0 // indirect golang.org/x/tools v0.1.12 // indirect - google.golang.org/protobuf v1.28.1 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gotest.tools/v3 v3.4.0 // indirect diff --git a/go.sum b/go.sum index 37e4bca..3a66849 100644 --- a/go.sum +++ b/go.sum @@ -78,8 +78,10 @@ contrib.go.opencensus.io/exporter/stackdriver v0.13.5/go.mod h1:aXENhDJ1Y4lIg4EU contrib.go.opencensus.io/integrations/ocsql v0.1.4/go.mod h1:8DsSdjz3F+APR+0z0WkU1aRorQCFfRxvqjUUPMbF3fE= contrib.go.opencensus.io/resource v0.1.1/go.mod h1:F361eGI91LCmW1I/Saf+rX0+OFcigGlFvXwEGEnkRLA= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230102084147-6d988866458e h1:Rdxx7d/iYk+nazy0/Z0Pp1f9XKiJ6b/3nyiw3MXOfHg= -git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230102084147-6d988866458e/go.mod h1:D7GGcFIH421mo6KuRaXXXmlXPwWeEsemTZG/BOZA/4o= +git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230109075652-ead0aeff2eb9 h1:8p7Kw3B7dbi2zdgG+Me9ETRWrJzoNVjcase4YqXfGbs= +git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230109075652-ead0aeff2eb9/go.mod h1:D7GGcFIH421mo6KuRaXXXmlXPwWeEsemTZG/BOZA/4o= +git.ohea.xyz/cursorius/runner-api/go/api/v2 v2.0.0-20230109074922-e20285fe6cf2 h1:G1XQEqhj1LZPQbH7avzvT7QL9Wfbb4CXMm0nLL39eDc= +git.ohea.xyz/cursorius/runner-api/go/api/v2 v2.0.0-20230109074922-e20285fe6cf2/go.mod h1:F9y5Ck4Wchsaj5amSX2eDRUlQ/iYP1VNLFduvjNwmLc= git.ohea.xyz/cursorius/webhooks/v6 v6.0.2-0.20221224221147-a2bdbf1756ed h1:gsK15m4Npow74+R6OfZKwwAg1sl7QWQCRXOeE2QLUco= git.ohea.xyz/cursorius/webhooks/v6 v6.0.2-0.20221224221147-a2bdbf1756ed/go.mod h1:64JKTmG3kupV+3+ZYJYPB/rGPEKw/diihhIj8lut4UA= git.ohea.xyz/golang/config v0.0.0-20220915224621-b9debd233173 h1:dhq/W6sa5KkLHVBwwgcNIPWcO4YK2/ecFTTln2W+1n8= diff --git a/pipeline_api/pipeline_api.go b/pipeline_api/pipeline_api.go index efe59dd..3c7490b 100644 --- a/pipeline_api/pipeline_api.go +++ b/pipeline_api/pipeline_api.go @@ -24,25 +24,32 @@ type ApiServer struct { } type RunnerWrapper struct { - runner runnermanager.Runner + runner *runnermanager.Runner mutex sync.Mutex } -func (r *RunnerWrapper) RunCommand(cmd string) (int64, string, string, error) { - r.mutex.Unlock() - defer r.mutex.Lock() +func (r *RunnerWrapper) RunCommand(cmd string, args []string) (int64, string, string, error) { + r.mutex.Lock() + defer r.mutex.Unlock() - return_code, stdout, stderr, err := r.runner.RunCommand(cmd) + return_code, stdout, stderr, err := r.runner.RunCommand(cmd, args) // TODO: run command by sending websocket packet // TODO: get stdout and stderr response return return_code, stdout, stderr, err } -func (s *ApiServer) GetRunnerFromMap(u uuid.UUID) *RunnerWrapper { +func (r *RunnerWrapper) Release() { + r.mutex.Lock() + defer r.mutex.Unlock() + r.runner.Release() +} + +func (s *ApiServer) GetRunnerFromMap(u uuid.UUID) (*RunnerWrapper, bool) { s.allocatedRunnersMutex.RLock() defer s.allocatedRunnersMutex.RUnlock() - return s.allocatedRunners[u] + runner, ok := s.allocatedRunners[u] + return runner, ok } func (s *ApiServer) GetRunner( @@ -69,7 +76,7 @@ func (s *ApiServer) GetRunner( return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("Could not get runner")) } - log.Info("Got runner with tags: %v", runnerTagsStr) + log.Infof("Got runner with tags: %v", runnerTagsStr.String()) runnerUuid := uuid.New() @@ -84,6 +91,29 @@ func (s *ApiServer) GetRunner( return res, nil } +func (s *ApiServer) ReleaseRunner( + ctx context.Context, + req *connect.Request[apiv2.ReleaseRunnerRequest], +) (*connect.Response[apiv2.ReleaseRunnerResponse], error) { + uuid, err := uuid.Parse(req.Msg.Id) + if err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("Invalid runner id")) + } + + log.Infof("Releasing runner with ID \"%v\"", uuid) + + s.allocatedRunnersMutex.Lock() + runner := s.allocatedRunners[uuid] + delete(s.allocatedRunners, uuid) + runner.Release() + s.allocatedRunnersMutex.Unlock() + + res := connect.NewResponse(&apiv2.ReleaseRunnerResponse{}) + res.Header().Set("ReleaseRunner-Version", "v2") + return res, nil + +} + func (s *ApiServer) RunCommand( ctx context.Context, req *connect.Request[apiv2.RunCommandRequest], @@ -94,16 +124,19 @@ func (s *ApiServer) RunCommand( return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("Invalid runner id")) } - runner := s.GetRunnerFromMap(uuid) + runner, ok := s.GetRunnerFromMap(uuid) + if !ok { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("Invalid runner id")) + } - return_code, stdout, stderr, err := runner.RunCommand(req.Msg.Command) + returnCode, stdout, stderr, err := runner.RunCommand(req.Msg.Command, req.Msg.Args) if err != nil { log.Errorf("Could not run command on runner \"%v\", %v", runner.runner.Id(), err) return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("Could not run command")) } res := connect.NewResponse(&apiv2.RunCommandResponse{ - ReturnCode: return_code, + ReturnCode: returnCode, Stdout: stdout, Stderr: stderr, }) @@ -118,6 +151,8 @@ func CreateHandler(mux *http.ServeMux, getRunnerCh chan runnermanager.GetRunnerR } path, handler := apiv2connect.NewGetRunnerServiceHandler(api_server) mux.Handle(path, handler) + path, handler = apiv2connect.NewReleaseRunnerServiceHandler(api_server) + mux.Handle(path, handler) path, handler = apiv2connect.NewRunCommandServiceHandler(api_server) mux.Handle(path, handler) } diff --git a/runnermanager/runner.go b/runnermanager/runner.go index e9676bc..1b76f04 100644 --- a/runnermanager/runner.go +++ b/runnermanager/runner.go @@ -1,8 +1,15 @@ package runnermanager -import "nhooyr.io/websocket" +import ( + "context" + "fmt" -//var log = logging.MustGetLogger("cursorius-server") + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "nhooyr.io/websocket" + + runner_api "git.ohea.xyz/cursorius/runner-api/go/api/v2" +) type RunnerData struct { msgType websocket.MessageType @@ -13,7 +20,7 @@ type Runner struct { id string tags []string conn *websocket.Conn - receiveChan chan RunnerData + receiveChan chan []byte running bool } @@ -21,6 +28,67 @@ func (r *Runner) Id() string { return r.id } -func (r *Runner) RunCommand(string) (int64, string, string, error) { - return 0, "", "", nil +func (r *Runner) Release() { + r.running = false +} + +func (r *Runner) RunCommand(cmd string, args []string) (returnCode int64, stdout string, stderr string, err error) { + + // Write RunCommand message to client + serverToRunnerMsg := &runner_api.ServerToRunnerMsg{ + Msg: &runner_api.ServerToRunnerMsg_RunCommandMsg{ + RunCommandMsg: &runner_api.RunCommand{ + Command: cmd, + Args: args, + }, + }, + } + + err = r.sendProtoStruct(serverToRunnerMsg) + if err != nil { + err = fmt.Errorf("Could not send command to client: %w", err) + return + } + + for { + // Read RunCommandFinalResponse message from client + data, ok := <-r.receiveChan + if !ok { + err = fmt.Errorf("Channel is closed on runner") + return + } + + runnerToServerMsg := &runner_api.RunnerToServerMsg{} + if err = proto.Unmarshal(data, runnerToServerMsg); err != nil { + err = fmt.Errorf("Could not parse RunCommand response: %w", err) + r.conn.Close(websocket.StatusUnsupportedData, "Invalid message") + return + } + + switch x := runnerToServerMsg.Msg.(type) { + case *runner_api.RunnerToServerMsg_RunCommandPartialResponseMsg: + stdout += x.RunCommandPartialResponseMsg.Stdout + stderr += x.RunCommandPartialResponseMsg.Stderr + case *runner_api.RunnerToServerMsg_RunCommandFinalResponseMsg: + stdout += x.RunCommandFinalResponseMsg.PartialResponse.Stdout + stderr += x.RunCommandFinalResponseMsg.PartialResponse.Stderr + returnCode = x.RunCommandFinalResponseMsg.ReturnCode + return + } + } +} + +func (r *Runner) sendProtoStruct(p protoreflect.ProtoMessage) error { + protoOut, err := proto.Marshal(p) + if err != nil { + return fmt.Errorf("Could not marshal proto: %w", err) + } + + ctx := context.Background() + + if err := r.conn.Write(ctx, websocket.MessageBinary, protoOut); err != nil { + return fmt.Errorf("Could not send proto to websocket: %w", err) + } + + return nil } diff --git a/runnermanager/runnermanager.go b/runnermanager/runnermanager.go index 9d014b3..ac570d0 100644 --- a/runnermanager/runnermanager.go +++ b/runnermanager/runnermanager.go @@ -6,10 +6,13 @@ import ( "strings" "time" - "git.ohea.xyz/cursorius/server/config" "github.com/op/go-logging" + "google.golang.org/protobuf/proto" "nhooyr.io/websocket" - "nhooyr.io/websocket/wsjson" + + "git.ohea.xyz/cursorius/server/config" + + runner_api "git.ohea.xyz/cursorius/runner-api/go/api/v2" ) var log = logging.MustGetLogger("cursorius-server") @@ -22,10 +25,11 @@ type RunnerRegistration struct { } type runnerManager struct { - getRunnerCh chan GetRunnerRequest - registerCh chan RunnerRegistration - connectedRunners []Runner - configuredRunners map[string]config.Runner + getRunnerCh chan GetRunnerRequest + registerCh chan RunnerRegistration + connectedRunners []Runner + numConnectedRunners uint64 + configuredRunners map[string]config.Runner } type GetRunnerRequest struct { @@ -34,7 +38,7 @@ type GetRunnerRequest struct { } type GetRunnerResponse struct { - Runner Runner + Runner *Runner Err error } @@ -54,6 +58,10 @@ func (r *runnerManager) processRequest(req GetRunnerRequest) { log.Debugf("Finding runner with tags %v", runnerTagsStr.String()) + foundRunner := false + + runnersToRemove := []int{} +runnerIter: for i, runner := range r.connectedRunners { // don't allocate runner that is already occupied if runner.running { @@ -72,25 +80,45 @@ func (r *runnerManager) processRequest(req GetRunnerRequest) { continue } log.Noticef("Removing defunct runner \"%v\"", runner.id) - // if the receive channel is closed, swap delete the runner as it's defunct - r.connectedRunners[i] = r.connectedRunners[len(r.connectedRunners)-1] - r.connectedRunners = r.connectedRunners[:len(r.connectedRunners)-1] + runnersToRemove = append(runnersToRemove, i) default: - runner.running = true + log.Debugf("Checking runner %v for requested tags", runner.id) + + tagIter: + for _, requestedTag := range req.Tags { + for _, posessedTag := range runner.tags { + if requestedTag == posessedTag { + continue tagIter + } + } + continue runnerIter + } + + r.connectedRunners[i].running = true + foundRunner = true req.RespChan <- GetRunnerResponse{ - Runner: runner, + Runner: &r.connectedRunners[i], Err: nil, } - return } + } + // since we iterate, all the indexes will be in accending order + for i, runnerInd := range runnersToRemove { + r.connectedRunners[runnerInd-i] = r.connectedRunners[len(r.connectedRunners)-1] + r.connectedRunners = r.connectedRunners[0 : len(r.connectedRunners)-2] + } + + if foundRunner { + return + } + errorMsg := "could not find valid runner" if len(r.connectedRunners) == 0 { errorMsg = "no connected runners" } - log.Errorf("Could not allocate runner with tags \"%v\": %v", runnerTagsStr.String(), errorMsg) req.RespChan <- GetRunnerResponse{ - Runner: Runner{}, + Runner: &Runner{}, Err: fmt.Errorf("Could not allocate runner: %v", errorMsg), } @@ -105,27 +133,32 @@ func (r *runnerManager) processRegistration(reg RunnerRegistration) { id: reg.Id, tags: reg.Tags, conn: reg.conn, - receiveChan: make(chan RunnerData), + receiveChan: make(chan []byte), running: false, } r.connectedRunners = append(r.connectedRunners, runner) // start goroutine to call Read function on websocket connection // this is required to keep the connection functioning go func() { + defer log.Noticef("Deregistered runner with id: %v", runner.id) + defer close(runner.receiveChan) for { msgType, data, err := reg.conn.Read(context.Background()) if err != nil { - // TODO: this is still racy, since a runner could be alloctade between the + // TODO: this is still racy, since a runner could be allocated between the // connection returning an err and the channel closing - close(runner.receiveChan) + // This should probably be handled by sending erroring, but not 100% sure log.Errorf("Could not read from connection: %v", err) - log.Noticef("Deregistering runner with id: %v", runner.id) - return - } else { - log.Debugf("%v: %v", msgType, data) - runner.receiveChan <- RunnerData{msgType: msgType, data: data} } + if msgType != websocket.MessageBinary { + close(runner.receiveChan) + log.Errorf("Got binary data from connection") + return + } + + runner.receiveChan <- data + } }() @@ -170,10 +203,28 @@ func RegisterRunner(conn *websocket.Conn, registerCh chan RunnerRegistration) { var registration RunnerRegistration registration.conn = conn - err := wsjson.Read(ctx, conn, ®istration) + + typ, r, err := conn.Read(ctx) if err != nil { - log.Errorf("Could not read data from websocket connection: %v", err) + log.Errorf("Could not read from runner websocket connection: %v", err) + log.Errorf("Disconnecting...") return } + if typ != websocket.MessageBinary { + log.Error("Got non binary message from runner, disconnecting...") + conn.Close(websocket.StatusUnsupportedData, "Requires binary data") + return + } + registration_proto := &runner_api.Register{} + if err := proto.Unmarshal(r, registration_proto); err != nil { + log.Error("Could not parse registration message from runner, disconnection....") + conn.Close(websocket.StatusUnsupportedData, "Invalid message") + return + } + + registration.Secret = registration_proto.Secret + registration.Id = registration_proto.Id + registration.Tags = registration_proto.Tags + registerCh <- registration }