Implment worker api (fixes #3)

This commit is contained in:
2023-01-09 00:59:19 -07:00
parent f4798233ba
commit 3d7bef3a82
5 changed files with 202 additions and 45 deletions
+73 -5
View File
@@ -1,8 +1,15 @@
package runnermanager
import "nhooyr.io/websocket"
import (
"context"
"fmt"
//var log = logging.MustGetLogger("cursorius-server")
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"nhooyr.io/websocket"
runner_api "git.ohea.xyz/cursorius/runner-api/go/api/v2"
)
type RunnerData struct {
msgType websocket.MessageType
@@ -13,7 +20,7 @@ type Runner struct {
id string
tags []string
conn *websocket.Conn
receiveChan chan RunnerData
receiveChan chan []byte
running bool
}
@@ -21,6 +28,67 @@ func (r *Runner) Id() string {
return r.id
}
func (r *Runner) RunCommand(string) (int64, string, string, error) {
return 0, "", "", nil
func (r *Runner) Release() {
r.running = false
}
func (r *Runner) RunCommand(cmd string, args []string) (returnCode int64, stdout string, stderr string, err error) {
// Write RunCommand message to client
serverToRunnerMsg := &runner_api.ServerToRunnerMsg{
Msg: &runner_api.ServerToRunnerMsg_RunCommandMsg{
RunCommandMsg: &runner_api.RunCommand{
Command: cmd,
Args: args,
},
},
}
err = r.sendProtoStruct(serverToRunnerMsg)
if err != nil {
err = fmt.Errorf("Could not send command to client: %w", err)
return
}
for {
// Read RunCommandFinalResponse message from client
data, ok := <-r.receiveChan
if !ok {
err = fmt.Errorf("Channel is closed on runner")
return
}
runnerToServerMsg := &runner_api.RunnerToServerMsg{}
if err = proto.Unmarshal(data, runnerToServerMsg); err != nil {
err = fmt.Errorf("Could not parse RunCommand response: %w", err)
r.conn.Close(websocket.StatusUnsupportedData, "Invalid message")
return
}
switch x := runnerToServerMsg.Msg.(type) {
case *runner_api.RunnerToServerMsg_RunCommandPartialResponseMsg:
stdout += x.RunCommandPartialResponseMsg.Stdout
stderr += x.RunCommandPartialResponseMsg.Stderr
case *runner_api.RunnerToServerMsg_RunCommandFinalResponseMsg:
stdout += x.RunCommandFinalResponseMsg.PartialResponse.Stdout
stderr += x.RunCommandFinalResponseMsg.PartialResponse.Stderr
returnCode = x.RunCommandFinalResponseMsg.ReturnCode
return
}
}
}
func (r *Runner) sendProtoStruct(p protoreflect.ProtoMessage) error {
protoOut, err := proto.Marshal(p)
if err != nil {
return fmt.Errorf("Could not marshal proto: %w", err)
}
ctx := context.Background()
if err := r.conn.Write(ctx, websocket.MessageBinary, protoOut); err != nil {
return fmt.Errorf("Could not send proto to websocket: %w", err)
}
return nil
}
+76 -25
View File
@@ -6,10 +6,13 @@ import (
"strings"
"time"
"git.ohea.xyz/cursorius/server/config"
"github.com/op/go-logging"
"google.golang.org/protobuf/proto"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"git.ohea.xyz/cursorius/server/config"
runner_api "git.ohea.xyz/cursorius/runner-api/go/api/v2"
)
var log = logging.MustGetLogger("cursorius-server")
@@ -22,10 +25,11 @@ type RunnerRegistration struct {
}
type runnerManager struct {
getRunnerCh chan GetRunnerRequest
registerCh chan RunnerRegistration
connectedRunners []Runner
configuredRunners map[string]config.Runner
getRunnerCh chan GetRunnerRequest
registerCh chan RunnerRegistration
connectedRunners []Runner
numConnectedRunners uint64
configuredRunners map[string]config.Runner
}
type GetRunnerRequest struct {
@@ -34,7 +38,7 @@ type GetRunnerRequest struct {
}
type GetRunnerResponse struct {
Runner Runner
Runner *Runner
Err error
}
@@ -54,6 +58,10 @@ func (r *runnerManager) processRequest(req GetRunnerRequest) {
log.Debugf("Finding runner with tags %v", runnerTagsStr.String())
foundRunner := false
runnersToRemove := []int{}
runnerIter:
for i, runner := range r.connectedRunners {
// don't allocate runner that is already occupied
if runner.running {
@@ -72,25 +80,45 @@ func (r *runnerManager) processRequest(req GetRunnerRequest) {
continue
}
log.Noticef("Removing defunct runner \"%v\"", runner.id)
// if the receive channel is closed, swap delete the runner as it's defunct
r.connectedRunners[i] = r.connectedRunners[len(r.connectedRunners)-1]
r.connectedRunners = r.connectedRunners[:len(r.connectedRunners)-1]
runnersToRemove = append(runnersToRemove, i)
default:
runner.running = true
log.Debugf("Checking runner %v for requested tags", runner.id)
tagIter:
for _, requestedTag := range req.Tags {
for _, posessedTag := range runner.tags {
if requestedTag == posessedTag {
continue tagIter
}
}
continue runnerIter
}
r.connectedRunners[i].running = true
foundRunner = true
req.RespChan <- GetRunnerResponse{
Runner: runner,
Runner: &r.connectedRunners[i],
Err: nil,
}
return
}
}
// since we iterate, all the indexes will be in accending order
for i, runnerInd := range runnersToRemove {
r.connectedRunners[runnerInd-i] = r.connectedRunners[len(r.connectedRunners)-1]
r.connectedRunners = r.connectedRunners[0 : len(r.connectedRunners)-2]
}
if foundRunner {
return
}
errorMsg := "could not find valid runner"
if len(r.connectedRunners) == 0 {
errorMsg = "no connected runners"
}
log.Errorf("Could not allocate runner with tags \"%v\": %v", runnerTagsStr.String(), errorMsg)
req.RespChan <- GetRunnerResponse{
Runner: Runner{},
Runner: &Runner{},
Err: fmt.Errorf("Could not allocate runner: %v", errorMsg),
}
@@ -105,27 +133,32 @@ func (r *runnerManager) processRegistration(reg RunnerRegistration) {
id: reg.Id,
tags: reg.Tags,
conn: reg.conn,
receiveChan: make(chan RunnerData),
receiveChan: make(chan []byte),
running: false,
}
r.connectedRunners = append(r.connectedRunners, runner)
// start goroutine to call Read function on websocket connection
// this is required to keep the connection functioning
go func() {
defer log.Noticef("Deregistered runner with id: %v", runner.id)
defer close(runner.receiveChan)
for {
msgType, data, err := reg.conn.Read(context.Background())
if err != nil {
// TODO: this is still racy, since a runner could be alloctade between the
// TODO: this is still racy, since a runner could be allocated between the
// connection returning an err and the channel closing
close(runner.receiveChan)
// This should probably be handled by sending erroring, but not 100% sure
log.Errorf("Could not read from connection: %v", err)
log.Noticef("Deregistering runner with id: %v", runner.id)
return
} else {
log.Debugf("%v: %v", msgType, data)
runner.receiveChan <- RunnerData{msgType: msgType, data: data}
}
if msgType != websocket.MessageBinary {
close(runner.receiveChan)
log.Errorf("Got binary data from connection")
return
}
runner.receiveChan <- data
}
}()
@@ -170,10 +203,28 @@ func RegisterRunner(conn *websocket.Conn, registerCh chan RunnerRegistration) {
var registration RunnerRegistration
registration.conn = conn
err := wsjson.Read(ctx, conn, &registration)
typ, r, err := conn.Read(ctx)
if err != nil {
log.Errorf("Could not read data from websocket connection: %v", err)
log.Errorf("Could not read from runner websocket connection: %v", err)
log.Errorf("Disconnecting...")
return
}
if typ != websocket.MessageBinary {
log.Error("Got non binary message from runner, disconnecting...")
conn.Close(websocket.StatusUnsupportedData, "Requires binary data")
return
}
registration_proto := &runner_api.Register{}
if err := proto.Unmarshal(r, registration_proto); err != nil {
log.Error("Could not parse registration message from runner, disconnection....")
conn.Close(websocket.StatusUnsupportedData, "Invalid message")
return
}
registration.Secret = registration_proto.Secret
registration.Id = registration_proto.Id
registration.Tags = registration_proto.Tags
registerCh <- registration
}