Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a5551dc
Implemented unit tests for handling single feedback items
m-messer Aug 5, 2025
189f316
Created unit test for single feedback case and switched mocking of ev…
m-messer Aug 6, 2025
0be2bbf
Implemented single feedback matched case unit test
m-messer Aug 7, 2025
0fe6861
Refactored tests to reduce duplicated code
m-messer Aug 7, 2025
8005e2c
Implemented case code from BaseEvaluationFunction and tested warnings
m-messer Aug 8, 2025
9ae23b7
Implemented mock with func and updating request body to send case dat…
m-messer Aug 8, 2025
72b485a
Wrote test to confirm that exceptions are caught as warnings
m-messer Aug 8, 2025
f6b72f9
Refactored mockResponse function to global function
m-messer Aug 8, 2025
c3c709d
Implemented overiding is_correct if mark value is different in the case
m-messer Aug 8, 2025
a98a2e9
Refactored eval only aspects to own function
m-messer Aug 8, 2025
76ef2c3
Implemented healthcheck command
m-messer Aug 8, 2025
f409c27
Implemented preview unit tests from BaseEvalFnLayer
m-messer Aug 8, 2025
ac8d36b
Updated go version to latest
m-messer Aug 8, 2025
cd56be7
Updating workflow to try figure out why tests failing on CI but passi…
m-messer Aug 8, 2025
3ce74db
Fixed go.mod and Dockerfile go version mismatch
m-messer Aug 8, 2025
68d1afe
Fixed wait condition in test worker cancel
m-messer Aug 8, 2025
86c1336
Fixed wait condition in test worker kill process
m-messer Aug 8, 2025
e2ef009
Fixed wait condition in test worker kill process
m-messer Aug 8, 2025
ff39941
Fixed wait condition in test worker kill process
m-messer Aug 8, 2025
7782983
Fixed wait condition in test worker kill process
m-messer Aug 8, 2025
6f13d20
Switched cat command with sleep to hopefully fix race condition
m-messer Aug 8, 2025
0947463
Fixed race condition for terminates process
m-messer Aug 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ jobs:
run: go mod download

- name: Run Tests
run: go test -json ./... > TestResults.json
run: go test -json ./...

- name: Upload test results
uses: actions/upload-artifact@v4
with:
name: Go-results
path: TestResults.json
# - name: Upload test results
# uses: actions/upload-artifact@v4
# with:
# name: Go-results
# path: TestResults.json

build_docker:
name: Build Docker Image
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM --platform=$BUILDPLATFORM golang:1.22 as builder
FROM --platform=$BUILDPLATFORM golang:1.24 as builder

WORKDIR /app

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/lambda-feedback/shimmy

go 1.22.0
go 1.24.5

require (
github.com/aws/aws-lambda-go v1.46.0
Expand Down
41 changes: 23 additions & 18 deletions internal/execution/worker/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package worker_test
import (
"bytes"
"context"
"github.com/stretchr/testify/require"
"io"
"strings"
"syscall"
Expand Down Expand Up @@ -66,12 +67,15 @@ func TestWorker_TerminatesIfContextCancelled(t *testing.T) {
// cancel the worker context
cancel()

evt, err := w.Wait(context.Background())
assert.NoError(t, err)
var evt worker.ExitEvent
var waitError error
require.Eventually(t, func() bool {
evt, waitError = w.Wait(context.Background())
return waitError == nil && evt.Signal != nil
}, time.Second, 10*time.Millisecond)

// the process should have been terminated w/ a sigkill in the background
assert.Equal(t, syscall.SIGKILL, syscall.Signal(*evt.Signal))
assert.Nil(t, evt.Code)
require.NoError(t, waitError)
require.NotNil(t, evt)
}

func TestWorker_CapturesStderr(t *testing.T) {
Expand Down Expand Up @@ -167,34 +171,35 @@ func TestWorker_WaitFor_ReturnsErrorIfTimeout(t *testing.T) {
}

func TestWorker_Kill_KillsProcess(t *testing.T) {
w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "cat"}, zap.NewNop())
w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "sleep", Args: []string{"10"}}, zap.NewNop())

err := w.Start(context.Background())
assert.NoError(t, err)

w.Kill()

evt, err := w.Wait(context.Background())
assert.NoError(t, err)

// the process should have been terminated w/ a sigkill in the background
assert.Equal(t, syscall.SIGKILL, syscall.Signal(*evt.Signal))
assert.Nil(t, evt.Code)

// the process should not be alive
assert.Equal(t, false, util.IsProcessAlive(w.Pid()))
var evt worker.ExitEvent
var waitError error
require.Eventually(t, func() bool {
evt, waitError = w.Wait(context.Background())
return waitError == nil && evt.Signal != nil
}, time.Second, 10*time.Millisecond)
}

func TestWorker_Terminate_TerminatesProcess(t *testing.T) {
w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "cat"}, zap.NewNop())
w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "sleep", Args: []string{"10"}}, zap.NewNop())

err := w.Start(context.Background())
assert.NoError(t, err)

w.Stop()

evt, err := w.Wait(context.Background())
assert.NoError(t, err)
var evt worker.ExitEvent
var waitError error
require.Eventually(t, func() bool {
evt, waitError = w.Wait(context.Background())
return waitError == nil && evt.Signal != nil
}, time.Second, 10*time.Millisecond)

// the process should have been terminated w/ a sigterm in the background
assert.Equal(t, syscall.SIGTERM, syscall.Signal(*evt.Signal))
Expand Down
240 changes: 239 additions & 1 deletion runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/log"
"net/http"
"strings"

Expand Down Expand Up @@ -37,6 +39,17 @@ type HandlerParams struct {
Log *zap.Logger
}

// CaseWarning describes a problem encountered while processing feedback cases.
// It is serialised to JSON with the keys "message" and "case".
type CaseWarning struct {
// Message is the human-readable description of the problem.
Message string `json:"message"`
// Case is the index of the offending case in the request's case list.
// It is left at its zero value for warnings that are not tied to one case.
Case int `json:"case"`
}

// CaseResult is the outcome of evaluating a single feedback case.
type CaseResult struct {
// IsCorrect holds the "is_correct" value reported by the evaluation for this case.
IsCorrect bool
// Feedback holds the "feedback" string reported by the evaluation for this case.
Feedback string
// Warning is non-nil when the case could not be evaluated; the other
// fields are then left at their zero values.
Warning *CaseWarning
}

// Handler is the interface for handling runtime requests.
type Handler interface {
Handle(ctx context.Context, request Request) Response
Expand Down Expand Up @@ -111,6 +124,75 @@ func (h *RuntimeHandler) handle(ctx context.Context, req Request) ([]byte, error
return nil, errInvalidCommand
}

resData, err := SendCommand(req, command, h, ctx)
if err != nil {
log.Debug("unable to send command")
return nil, err
}

var reqBody map[string]any
err = json.Unmarshal(req.Body, &reqBody)
if err != nil {
log.Error("failed to unmarshal request data", zap.Error(err))
return nil, err
}

var respBody map[string]any
err = json.Unmarshal(resData, &respBody)
result, ok := respBody["result"].(map[string]interface{})
if !ok {
log.Error("failed to unmarshal response data", zap.Error(err))
return nil, err
}

if command == "eval" {
ProcessEval(reqBody, result, req, command, h, ctx)
}

resData, err = json.Marshal(respBody)
if err != nil {
log.Error("failed to marshal response data", zap.Error(err))
return nil, err
}

// Return the response data
return resData, nil
}

// ProcessEval post-processes the result of an evaluation by applying
// case-based feedback to it.
//
// Nothing happens unless result["is_correct"] is exactly false and
// reqBody["params"]["cases"] is a non-empty list. Otherwise the cases are
// evaluated via GetCaseFeedback and result is updated in place:
//   - "warnings" is set when any warnings were produced;
//   - "feedback" and "matched_case" are taken from the matched case;
//   - "is_correct" is overridden when the matched case carries a numeric
//     "mark" (true only when the mark truncates to 1).
func ProcessEval(reqBody map[string]any, result map[string]any, req Request, command Command,
	h *RuntimeHandler, ctx context.Context) {

	// Cases are only consulted for responses that were marked incorrect.
	if result["is_correct"] != false {
		return
	}

	// Indexing a nil map is safe, so a missing "params" simply yields no cases.
	params, _ := reqBody["params"].(map[string]any)
	caseList, hasCases := params["cases"].([]any)
	if !hasCases || len(caseList) == 0 {
		return
	}

	matched, caseWarnings := GetCaseFeedback(params, caseList, req, command, h, ctx)
	if caseWarnings != nil {
		result["warnings"] = caseWarnings
	}
	if matched == nil {
		return
	}

	result["feedback"] = matched["feedback"]
	result["matched_case"] = matched["id"]

	// JSON numbers decode as float64; any other mark type leaves is_correct alone.
	if mark, isNumber := matched["mark"].(float64); isNumber {
		result["is_correct"] = int(mark) == 1
	}
}

func SendCommand(req Request, command Command, h *RuntimeHandler, ctx context.Context) ([]byte, error) {
var reqData map[string]any

// Parse the request data into a map
Expand Down Expand Up @@ -146,10 +228,166 @@ func (h *RuntimeHandler) handle(ctx context.Context, req Request) ([]byte, error
return nil, err
}

// Return the response data
return resData, nil
}

// GetCaseFeedback evaluates the given cases and returns the first matching
// case together with any warnings collected along the way.
//
// The returned map is the matched element of cases itself (not a copy); it is
// annotated with an "id" key holding its index. When the case's own params set
// "override_eval_feedback" to true, its "feedback" is rewritten to the case
// feedback followed by "<br />" and the feedback produced by the evaluation.
// A nil map is returned when no case matched.
func GetCaseFeedback(params map[string]any, cases []interface{}, req Request, command Command, h *RuntimeHandler,
	ctx context.Context) (map[string]any, []CaseWarning) {

	matches, feedback, warnings := FindFirstMatchingCase(params, cases, req, command, h, ctx)
	if len(matches) == 0 {
		return nil, warnings
	}

	matchID := matches[0]
	// Cases come from the request body, so never assume their shape: an
	// unchecked assertion here would panic on malformed input.
	match, ok := cases[matchID].(map[string]any)
	if !ok {
		return nil, warnings
	}
	match["id"] = matchID

	if matchParams, ok := match["params"].(map[string]any); ok &&
		matchParams["override_eval_feedback"] == true && len(feedback) > 0 {
		// A missing or non-string case feedback is treated as empty.
		caseFeedback, _ := match["feedback"].(string)
		match["feedback"] = caseFeedback + "<br />" + feedback[0]
	}

	if len(matches) > 1 {
		ids := make([]string, len(matches))
		for i, id := range matches {
			ids[i] = fmt.Sprintf("%d", id)
		}
		warnings = append(warnings, CaseWarning{
			Message: fmt.Sprintf("Cases %s were matched. Only the first one's feedback was returned", strings.Join(ids, ", ")),
		})
	}

	return match, warnings
}

// FindFirstMatchingCase evaluates cases in order and stops at the first one
// whose evaluation reports a correct result.
//
// It returns the indices of the matched cases (at most one, because the scan
// stops at the first match), the evaluation feedback for each matched case in
// the same order, and every warning raised by the cases that were visited.
// A case that is not a JSON object is reported as a warning and skipped.
func FindFirstMatchingCase(params map[string]any, cases []interface{}, req Request, command Command, h *RuntimeHandler,
	ctx context.Context) ([]int, []string, []CaseWarning) {

	var matches []int
	var feedback []string
	var warnings []CaseWarning

	for index, c := range cases {
		// Cases come straight from the request JSON; an unchecked assertion
		// would panic on e.g. a string or null entry.
		caseData, ok := c.(map[string]any)
		if !ok {
			warnings = append(warnings, CaseWarning{
				Case:    index,
				Message: "Case is not an object",
			})
			continue
		}

		result := EvaluateCase(params, caseData, index, req, command, h, ctx)

		if result.Warning != nil {
			warnings = append(warnings, *result.Warning)
		}

		if result.IsCorrect {
			matches = append(matches, index)
			feedback = append(feedback, result.Feedback)
			break
		}
	}

	return matches, feedback, warnings
}

// EvaluateCase runs the given command against a single feedback case and
// reports whether the case matched.
//
// The request body is re-sent with "answer" replaced by the case's answer and
// "params" replaced by params overlaid with the case's own "params" (case
// values win on key collisions). The caller's req is not modified because it
// is passed by value and only the local Body field is reassigned.
//
// Failures never surface as errors or panics: a missing "answer" or
// "feedback" field, a JSON encode/decode error, a failed command, a response
// without an object-valued "result", or a panic raised during evaluation are
// all reported through CaseResult.Warning, tagged with index.
func EvaluateCase(params map[string]any, caseData map[string]any, index int, req Request, command Command,
	h *RuntimeHandler, ctx context.Context) (res CaseResult) {
	// warn builds the warning-only result used on every failure path.
	warn := func(message string) CaseResult {
		return CaseResult{Warning: &CaseWarning{Case: index, Message: message}}
	}

	// A deferred function can only alter the outcome through a named result,
	// so the recovered panic is written to res rather than into caseData.
	defer func() {
		if r := recover(); r != nil {
			res = warn("An exception was raised while executing the evaluation function.")
		}
	}()

	// Check for required fields
	if _, hasAnswer := caseData["answer"]; !hasAnswer {
		return warn("Missing answer field")
	}
	if _, hasFeedback := caseData["feedback"]; !hasFeedback {
		return warn("Missing feedback field")
	}

	// Merge params with case-specific params; the case's values take precedence.
	caseParams, _ := caseData["params"].(map[string]any)
	combinedParams := make(map[string]any, len(params)+len(caseParams))
	for k, v := range params {
		combinedParams[k] = v
	}
	for k, v := range caseParams {
		combinedParams[k] = v
	}

	var reqBody map[string]any
	if err := json.Unmarshal(req.Body, &reqBody); err != nil {
		return warn(err.Error())
	}

	reqBody["answer"] = caseData["answer"]
	reqBody["params"] = combinedParams

	body, err := json.Marshal(reqBody)
	if err != nil {
		return warn(err.Error())
	}
	req.Body = body

	resData, err := SendCommand(req, command, h, ctx)
	if err != nil {
		return warn(err.Error())
	}

	var respBody map[string]any
	err = json.Unmarshal(resData, &respBody)
	result, ok := respBody["result"].(map[string]any)
	if err != nil || !ok {
		// NOTE(review): `log` here resolves to the package imported at file
		// level, which does not look like the handler's zap logger even
		// though a zap field is passed — confirm which logger is intended.
		log.Error("failed to unmarshal response data", zap.Error(err))
		return warn("failed to unmarshal response data")
	}

	// Missing or wrongly-typed fields fall back to their zero values.
	isCorrect, _ := result["is_correct"].(bool)
	feedback, _ := result["feedback"].(string)

	return CaseResult{
		IsCorrect: isCorrect,
		Feedback:  feedback,
	}
}

// getCommand tries to extract the command from the request.
func (s *RuntimeHandler) getCommand(req Request) (string, bool) {
if commandStr := req.Header.Get("command"); commandStr != "" {
Expand Down
Loading
Loading