Implment worker api (fixes #3)

This commit is contained in:
2023-01-09 00:59:19 -07:00
parent f4798233ba
commit 3d7bef3a82
5 changed files with 202 additions and 45 deletions
+3 -2
View File
@@ -3,7 +3,8 @@ module git.ohea.xyz/cursorius/server
go 1.19
require (
git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230102084147-6d988866458e
git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230109075652-ead0aeff2eb9
git.ohea.xyz/cursorius/runner-api/go/api/v2 v2.0.0-20230109074922-e20285fe6cf2
git.ohea.xyz/golang/config v0.0.0-20220915224621-b9debd233173
github.com/bufbuild/connect-go v1.4.1
github.com/docker/docker v20.10.22+incompatible
@@ -12,6 +13,7 @@ require (
github.com/google/uuid v1.3.0
github.com/op/go-logging v0.0.0-20160315200505-970db520ece7
golang.org/x/net v0.2.0
google.golang.org/protobuf v1.28.1
nhooyr.io/websocket v1.8.7
)
@@ -50,7 +52,6 @@ require (
golang.org/x/sys v0.2.0 // indirect
golang.org/x/text v0.4.0 // indirect
golang.org/x/tools v0.1.12 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.4.0 // indirect
+4 -2
View File
@@ -78,8 +78,10 @@ contrib.go.opencensus.io/exporter/stackdriver v0.13.5/go.mod h1:aXENhDJ1Y4lIg4EU
contrib.go.opencensus.io/integrations/ocsql v0.1.4/go.mod h1:8DsSdjz3F+APR+0z0WkU1aRorQCFfRxvqjUUPMbF3fE=
contrib.go.opencensus.io/resource v0.1.1/go.mod h1:F361eGI91LCmW1I/Saf+rX0+OFcigGlFvXwEGEnkRLA=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230102084147-6d988866458e h1:Rdxx7d/iYk+nazy0/Z0Pp1f9XKiJ6b/3nyiw3MXOfHg=
git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230102084147-6d988866458e/go.mod h1:D7GGcFIH421mo6KuRaXXXmlXPwWeEsemTZG/BOZA/4o=
git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230109075652-ead0aeff2eb9 h1:8p7Kw3B7dbi2zdgG+Me9ETRWrJzoNVjcase4YqXfGbs=
git.ohea.xyz/cursorius/pipeline-api/go/api/v2 v2.0.0-20230109075652-ead0aeff2eb9/go.mod h1:D7GGcFIH421mo6KuRaXXXmlXPwWeEsemTZG/BOZA/4o=
git.ohea.xyz/cursorius/runner-api/go/api/v2 v2.0.0-20230109074922-e20285fe6cf2 h1:G1XQEqhj1LZPQbH7avzvT7QL9Wfbb4CXMm0nLL39eDc=
git.ohea.xyz/cursorius/runner-api/go/api/v2 v2.0.0-20230109074922-e20285fe6cf2/go.mod h1:F9y5Ck4Wchsaj5amSX2eDRUlQ/iYP1VNLFduvjNwmLc=
git.ohea.xyz/cursorius/webhooks/v6 v6.0.2-0.20221224221147-a2bdbf1756ed h1:gsK15m4Npow74+R6OfZKwwAg1sl7QWQCRXOeE2QLUco=
git.ohea.xyz/cursorius/webhooks/v6 v6.0.2-0.20221224221147-a2bdbf1756ed/go.mod h1:64JKTmG3kupV+3+ZYJYPB/rGPEKw/diihhIj8lut4UA=
git.ohea.xyz/golang/config v0.0.0-20220915224621-b9debd233173 h1:dhq/W6sa5KkLHVBwwgcNIPWcO4YK2/ecFTTln2W+1n8=
+46 -11
View File
@@ -24,25 +24,32 @@ type ApiServer struct {
}
type RunnerWrapper struct {
runner runnermanager.Runner
runner *runnermanager.Runner
mutex sync.Mutex
}
func (r *RunnerWrapper) RunCommand(cmd string) (int64, string, string, error) {
r.mutex.Unlock()
defer r.mutex.Lock()
func (r *RunnerWrapper) RunCommand(cmd string, args []string) (int64, string, string, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
return_code, stdout, stderr, err := r.runner.RunCommand(cmd)
return_code, stdout, stderr, err := r.runner.RunCommand(cmd, args)
// TODO: run command by sending websocket packet
// TODO: get stdout and stderr response
return return_code, stdout, stderr, err
}
func (s *ApiServer) GetRunnerFromMap(u uuid.UUID) *RunnerWrapper {
func (r *RunnerWrapper) Release() {
r.mutex.Lock()
defer r.mutex.Unlock()
r.runner.Release()
}
func (s *ApiServer) GetRunnerFromMap(u uuid.UUID) (*RunnerWrapper, bool) {
s.allocatedRunnersMutex.RLock()
defer s.allocatedRunnersMutex.RUnlock()
return s.allocatedRunners[u]
runner, ok := s.allocatedRunners[u]
return runner, ok
}
func (s *ApiServer) GetRunner(
@@ -69,7 +76,7 @@ func (s *ApiServer) GetRunner(
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("Could not get runner"))
}
log.Info("Got runner with tags: %v", runnerTagsStr)
log.Infof("Got runner with tags: %v", runnerTagsStr.String())
runnerUuid := uuid.New()
@@ -84,6 +91,29 @@ func (s *ApiServer) GetRunner(
return res, nil
}
func (s *ApiServer) ReleaseRunner(
ctx context.Context,
req *connect.Request[apiv2.ReleaseRunnerRequest],
) (*connect.Response[apiv2.ReleaseRunnerResponse], error) {
uuid, err := uuid.Parse(req.Msg.Id)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("Invalid runner id"))
}
log.Infof("Releasing runner with ID \"%v\"", uuid)
s.allocatedRunnersMutex.Lock()
runner := s.allocatedRunners[uuid]
delete(s.allocatedRunners, uuid)
runner.Release()
s.allocatedRunnersMutex.Unlock()
res := connect.NewResponse(&apiv2.ReleaseRunnerResponse{})
res.Header().Set("ReleaseRunner-Version", "v2")
return res, nil
}
func (s *ApiServer) RunCommand(
ctx context.Context,
req *connect.Request[apiv2.RunCommandRequest],
@@ -94,16 +124,19 @@ func (s *ApiServer) RunCommand(
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("Invalid runner id"))
}
runner := s.GetRunnerFromMap(uuid)
runner, ok := s.GetRunnerFromMap(uuid)
if !ok {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("Invalid runner id"))
}
return_code, stdout, stderr, err := runner.RunCommand(req.Msg.Command)
returnCode, stdout, stderr, err := runner.RunCommand(req.Msg.Command, req.Msg.Args)
if err != nil {
log.Errorf("Could not run command on runner \"%v\", %v", runner.runner.Id(), err)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("Could not run command"))
}
res := connect.NewResponse(&apiv2.RunCommandResponse{
ReturnCode: return_code,
ReturnCode: returnCode,
Stdout: stdout,
Stderr: stderr,
})
@@ -118,6 +151,8 @@ func CreateHandler(mux *http.ServeMux, getRunnerCh chan runnermanager.GetRunnerR
}
path, handler := apiv2connect.NewGetRunnerServiceHandler(api_server)
mux.Handle(path, handler)
path, handler = apiv2connect.NewReleaseRunnerServiceHandler(api_server)
mux.Handle(path, handler)
path, handler = apiv2connect.NewRunCommandServiceHandler(api_server)
mux.Handle(path, handler)
}
+73 -5
View File
@@ -1,8 +1,15 @@
package runnermanager
import "nhooyr.io/websocket"
import (
"context"
"fmt"
//var log = logging.MustGetLogger("cursorius-server")
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"nhooyr.io/websocket"
runner_api "git.ohea.xyz/cursorius/runner-api/go/api/v2"
)
type RunnerData struct {
msgType websocket.MessageType
@@ -13,7 +20,7 @@ type Runner struct {
id string
tags []string
conn *websocket.Conn
receiveChan chan RunnerData
receiveChan chan []byte
running bool
}
@@ -21,6 +28,67 @@ func (r *Runner) Id() string {
return r.id
}
func (r *Runner) RunCommand(string) (int64, string, string, error) {
return 0, "", "", nil
func (r *Runner) Release() {
r.running = false
}
func (r *Runner) RunCommand(cmd string, args []string) (returnCode int64, stdout string, stderr string, err error) {
// Write RunCommand message to client
serverToRunnerMsg := &runner_api.ServerToRunnerMsg{
Msg: &runner_api.ServerToRunnerMsg_RunCommandMsg{
RunCommandMsg: &runner_api.RunCommand{
Command: cmd,
Args: args,
},
},
}
err = r.sendProtoStruct(serverToRunnerMsg)
if err != nil {
err = fmt.Errorf("Could not send command to client: %w", err)
return
}
for {
// Read RunCommandFinalResponse message from client
data, ok := <-r.receiveChan
if !ok {
err = fmt.Errorf("Channel is closed on runner")
return
}
runnerToServerMsg := &runner_api.RunnerToServerMsg{}
if err = proto.Unmarshal(data, runnerToServerMsg); err != nil {
err = fmt.Errorf("Could not parse RunCommand response: %w", err)
r.conn.Close(websocket.StatusUnsupportedData, "Invalid message")
return
}
switch x := runnerToServerMsg.Msg.(type) {
case *runner_api.RunnerToServerMsg_RunCommandPartialResponseMsg:
stdout += x.RunCommandPartialResponseMsg.Stdout
stderr += x.RunCommandPartialResponseMsg.Stderr
case *runner_api.RunnerToServerMsg_RunCommandFinalResponseMsg:
stdout += x.RunCommandFinalResponseMsg.PartialResponse.Stdout
stderr += x.RunCommandFinalResponseMsg.PartialResponse.Stderr
returnCode = x.RunCommandFinalResponseMsg.ReturnCode
return
}
}
}
func (r *Runner) sendProtoStruct(p protoreflect.ProtoMessage) error {
protoOut, err := proto.Marshal(p)
if err != nil {
return fmt.Errorf("Could not marshal proto: %w", err)
}
ctx := context.Background()
if err := r.conn.Write(ctx, websocket.MessageBinary, protoOut); err != nil {
return fmt.Errorf("Could not send proto to websocket: %w", err)
}
return nil
}
+76 -25
View File
@@ -6,10 +6,13 @@ import (
"strings"
"time"
"git.ohea.xyz/cursorius/server/config"
"github.com/op/go-logging"
"google.golang.org/protobuf/proto"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"git.ohea.xyz/cursorius/server/config"
runner_api "git.ohea.xyz/cursorius/runner-api/go/api/v2"
)
var log = logging.MustGetLogger("cursorius-server")
@@ -22,10 +25,11 @@ type RunnerRegistration struct {
}
type runnerManager struct {
getRunnerCh chan GetRunnerRequest
registerCh chan RunnerRegistration
connectedRunners []Runner
configuredRunners map[string]config.Runner
getRunnerCh chan GetRunnerRequest
registerCh chan RunnerRegistration
connectedRunners []Runner
numConnectedRunners uint64
configuredRunners map[string]config.Runner
}
type GetRunnerRequest struct {
@@ -34,7 +38,7 @@ type GetRunnerRequest struct {
}
type GetRunnerResponse struct {
Runner Runner
Runner *Runner
Err error
}
@@ -54,6 +58,10 @@ func (r *runnerManager) processRequest(req GetRunnerRequest) {
log.Debugf("Finding runner with tags %v", runnerTagsStr.String())
foundRunner := false
runnersToRemove := []int{}
runnerIter:
for i, runner := range r.connectedRunners {
// don't allocate runner that is already occupied
if runner.running {
@@ -72,25 +80,45 @@ func (r *runnerManager) processRequest(req GetRunnerRequest) {
continue
}
log.Noticef("Removing defunct runner \"%v\"", runner.id)
// if the receive channel is closed, swap delete the runner as it's defunct
r.connectedRunners[i] = r.connectedRunners[len(r.connectedRunners)-1]
r.connectedRunners = r.connectedRunners[:len(r.connectedRunners)-1]
runnersToRemove = append(runnersToRemove, i)
default:
runner.running = true
log.Debugf("Checking runner %v for requested tags", runner.id)
tagIter:
for _, requestedTag := range req.Tags {
for _, posessedTag := range runner.tags {
if requestedTag == posessedTag {
continue tagIter
}
}
continue runnerIter
}
r.connectedRunners[i].running = true
foundRunner = true
req.RespChan <- GetRunnerResponse{
Runner: runner,
Runner: &r.connectedRunners[i],
Err: nil,
}
return
}
}
// since we iterate, all the indexes will be in accending order
for i, runnerInd := range runnersToRemove {
r.connectedRunners[runnerInd-i] = r.connectedRunners[len(r.connectedRunners)-1]
r.connectedRunners = r.connectedRunners[0 : len(r.connectedRunners)-2]
}
if foundRunner {
return
}
errorMsg := "could not find valid runner"
if len(r.connectedRunners) == 0 {
errorMsg = "no connected runners"
}
log.Errorf("Could not allocate runner with tags \"%v\": %v", runnerTagsStr.String(), errorMsg)
req.RespChan <- GetRunnerResponse{
Runner: Runner{},
Runner: &Runner{},
Err: fmt.Errorf("Could not allocate runner: %v", errorMsg),
}
@@ -105,27 +133,32 @@ func (r *runnerManager) processRegistration(reg RunnerRegistration) {
id: reg.Id,
tags: reg.Tags,
conn: reg.conn,
receiveChan: make(chan RunnerData),
receiveChan: make(chan []byte),
running: false,
}
r.connectedRunners = append(r.connectedRunners, runner)
// start goroutine to call Read function on websocket connection
// this is required to keep the connection functioning
go func() {
defer log.Noticef("Deregistered runner with id: %v", runner.id)
defer close(runner.receiveChan)
for {
msgType, data, err := reg.conn.Read(context.Background())
if err != nil {
// TODO: this is still racy, since a runner could be alloctade between the
// TODO: this is still racy, since a runner could be allocated between the
// connection returning an err and the channel closing
close(runner.receiveChan)
// This should probably be handled by sending erroring, but not 100% sure
log.Errorf("Could not read from connection: %v", err)
log.Noticef("Deregistering runner with id: %v", runner.id)
return
} else {
log.Debugf("%v: %v", msgType, data)
runner.receiveChan <- RunnerData{msgType: msgType, data: data}
}
if msgType != websocket.MessageBinary {
close(runner.receiveChan)
log.Errorf("Got binary data from connection")
return
}
runner.receiveChan <- data
}
}()
@@ -170,10 +203,28 @@ func RegisterRunner(conn *websocket.Conn, registerCh chan RunnerRegistration) {
var registration RunnerRegistration
registration.conn = conn
err := wsjson.Read(ctx, conn, &registration)
typ, r, err := conn.Read(ctx)
if err != nil {
log.Errorf("Could not read data from websocket connection: %v", err)
log.Errorf("Could not read from runner websocket connection: %v", err)
log.Errorf("Disconnecting...")
return
}
if typ != websocket.MessageBinary {
log.Error("Got non binary message from runner, disconnecting...")
conn.Close(websocket.StatusUnsupportedData, "Requires binary data")
return
}
registration_proto := &runner_api.Register{}
if err := proto.Unmarshal(r, registration_proto); err != nil {
log.Error("Could not parse registration message from runner, disconnection....")
conn.Close(websocket.StatusUnsupportedData, "Invalid message")
return
}
registration.Secret = registration_proto.Secret
registration.Id = registration_proto.Id
registration.Tags = registration_proto.Tags
registerCh <- registration
}