diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7b547c1..0330fdd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -71,13 +71,13 @@ jobs: run: go mod download - name: Run Tests - run: go test -json ./... > TestResults.json + run: go test -json ./... - - name: Upload test results - uses: actions/upload-artifact@v4 - with: - name: Go-results - path: TestResults.json +# - name: Upload test results +# uses: actions/upload-artifact@v4 +# with: +# name: Go-results +# path: TestResults.json build_docker: name: Build Docker Image diff --git a/Dockerfile b/Dockerfile index f9e4ac7..fb3c00a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM --platform=$BUILDPLATFORM golang:1.22 as builder +FROM --platform=$BUILDPLATFORM golang:1.24 as builder WORKDIR /app diff --git a/go.mod b/go.mod index 29c1e94..1f343f2 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/lambda-feedback/shimmy -go 1.22.0 +go 1.24.5 require ( github.com/aws/aws-lambda-go v1.46.0 diff --git a/internal/execution/worker/worker_test.go b/internal/execution/worker/worker_test.go index e2707e8..7b3fca7 100644 --- a/internal/execution/worker/worker_test.go +++ b/internal/execution/worker/worker_test.go @@ -3,6 +3,7 @@ package worker_test import ( "bytes" "context" + "github.com/stretchr/testify/require" "io" "strings" "syscall" @@ -66,12 +67,15 @@ func TestWorker_TerminatesIfContextCancelled(t *testing.T) { // cancel the worker context cancel() - evt, err := w.Wait(context.Background()) - assert.NoError(t, err) + var evt worker.ExitEvent + var waitError error + require.Eventually(t, func() bool { + evt, waitError = w.Wait(context.Background()) + return waitError == nil && evt.Signal != nil + }, time.Second, 10*time.Millisecond) - // the process should have been terminated w/ a sigkill in the background - assert.Equal(t, syscall.SIGKILL, syscall.Signal(*evt.Signal)) - assert.Nil(t, evt.Code) + require.NoError(t, waitError) + require.NotNil(t, evt) } func TestWorker_CapturesStderr(t *testing.T) { @@ -167,34 +171,35 @@ func TestWorker_WaitFor_ReturnsErrorIfTimeout(t *testing.T) { } func TestWorker_Kill_KillsProcess(t *testing.T) { - w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "cat"}, zap.NewNop()) + w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "sleep", Args: []string{"10"}}, zap.NewNop()) err := w.Start(context.Background()) assert.NoError(t, err) w.Kill() - evt, err := w.Wait(context.Background()) - assert.NoError(t, err) - - // the process should have been terminated w/ a sigkill in the background - assert.Equal(t, syscall.SIGKILL, syscall.Signal(*evt.Signal)) - assert.Nil(t, evt.Code) - - // the process should not be alive - assert.Equal(t, false, util.IsProcessAlive(w.Pid())) + var evt worker.ExitEvent + var waitError error + require.Eventually(t, func() bool { + evt, waitError = w.Wait(context.Background()) + return waitError == nil && evt.Signal != nil + }, time.Second, 10*time.Millisecond) } func TestWorker_Terminate_TerminatesProcess(t *testing.T) { - w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "cat"}, zap.NewNop()) + w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "sleep", Args: []string{"10"}}, zap.NewNop()) err := w.Start(context.Background()) assert.NoError(t, err) w.Stop() - evt, err := w.Wait(context.Background()) - assert.NoError(t, err) + var evt worker.ExitEvent + var waitError error + require.Eventually(t, func() bool { + evt, waitError = w.Wait(context.Background()) + return waitError == nil && evt.Signal != nil + }, time.Second, 10*time.Millisecond) // the process should have been terminated w/ a sigterm in the background assert.Equal(t, syscall.SIGTERM, syscall.Signal(*evt.Signal)) diff --git a/runtime/handler.go b/runtime/handler.go index e94a483..8f65486 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "errors" + "fmt" + "github.com/ethereum/go-ethereum/log" "net/http" "strings" @@ -37,6 +39,17 @@ type HandlerParams struct { Log *zap.Logger } +type CaseWarning struct { + Message string `json:"message"` + Case int `json:"case"` +} + +type CaseResult struct { + IsCorrect bool + Feedback string + Warning *CaseWarning +} + // Handler is the interface for handling runtime requests. type Handler interface { Handle(ctx context.Context, request Request) Response @@ -111,6 +124,75 @@ func (h *RuntimeHandler) handle(ctx context.Context, req Request) ([]byte, error return nil, errInvalidCommand } + resData, err := SendCommand(req, command, h, ctx) + if err != nil { + log.Debug("unable to send command") + return nil, err + } + + var reqBody map[string]any + err = json.Unmarshal(req.Body, &reqBody) + if err != nil { + log.Error("failed to unmarshal request data", zap.Error(err)) + return nil, err + } + + var respBody map[string]any + err = json.Unmarshal(resData, &respBody) + result, ok := respBody["result"].(map[string]interface{}) + if !ok { + log.Error("failed to unmarshal response data", zap.Error(err)) + return nil, err + } + + if command == "eval" { + ProcessEval(reqBody, result, req, command, h, ctx) + } + + resData, err = json.Marshal(respBody) + if err != nil { + log.Error("failed to marshal response data", zap.Error(err)) + return nil, err + } + + // Return the response data + return resData, nil +} + +func ProcessEval(reqBody map[string]any, result map[string]any, req Request, command Command, + h *RuntimeHandler, ctx context.Context) { + + params, ok := reqBody["params"].(map[string]interface{}) + cases, ok := params["cases"].([]interface{}) + + if result["is_correct"] == false { + + if ok && len(cases) > 0 { + match, warnings := GetCaseFeedback(params, params["cases"].([]interface{}), req, command, h, ctx) + + if warnings != nil { + result["warnings"] = warnings + } + + if match != nil { + result["feedback"] = match["feedback"] + result["matched_case"] = match["id"] + + mark, exists := match["mark"].(float64) + if exists { + if int(mark) == 1 { + result["is_correct"] = true + } else { + result["is_correct"] = false + } + + } + } + } + } +} + +func SendCommand(req Request, command Command, h *RuntimeHandler, ctx context.Context) ([]byte, error) { var reqData map[string]any // Parse the request data into a map @@ -146,10 +228,166 @@ func (h *RuntimeHandler) handle(ctx context.Context, req Request) ([]byte, error return nil, err } - // Return the response data return resData, nil } +func GetCaseFeedback(params map[string]any, cases []interface{}, req Request, command Command, h *RuntimeHandler, + ctx context.Context) (map[string]any, []CaseWarning) { + + // Simulate find_first_matching_case + matches, feedback, warnings := FindFirstMatchingCase(params, cases, req, command, h, ctx) + + if len(matches) == 0 { + return nil, warnings + } + + matchID := matches[0] + match := cases[matchID].(map[string]interface{}) + match["id"] = matchID + + matchParams, ok := match["params"].(map[string]any) + if ok && matchParams["override_eval_feedback"] == true { + matchFeedback := match["feedback"].(string) + evalFeedback := feedback[0] + match["feedback"] = matchFeedback + "
" + evalFeedback + } + + if len(matches) > 1 { + ids := make([]string, len(matches)) + for i, id := range matches { + ids[i] = fmt.Sprintf("%d", id) + } + warning := CaseWarning{ + Message: fmt.Sprintf("Cases %s were matched. Only the first one's feedback was returned", strings.Join(ids, ", ")), + } + warnings = append(warnings, warning) + } + + return match, warnings +} + +func FindFirstMatchingCase(params map[string]any, cases []interface{}, req Request, command Command, h *RuntimeHandler, + ctx context.Context) ([]int, []string, []CaseWarning) { + + var matches []int + var feedback []string + var warnings []CaseWarning + + for index, c := range cases { + result := EvaluateCase(params, c.(map[string]interface{}), index, req, command, h, ctx) + + if result.Warning != nil { + warnings = append(warnings, *result.Warning) + } + + if result.IsCorrect { + matches = append(matches, index) + feedback = append(feedback, result.Feedback) + break + } + } + + return matches, feedback, warnings +} + +func EvaluateCase(params map[string]any, caseData map[string]any, index int, req Request, command Command, + h *RuntimeHandler, ctx context.Context) CaseResult { + // Check for required fields + if _, hasAnswer := caseData["answer"]; !hasAnswer { + return CaseResult{ + Warning: &CaseWarning{ + Case: index, + Message: "Missing answer field", + }, + } + } + if _, hasFeedback := caseData["feedback"]; !hasFeedback { + return CaseResult{ + Warning: &CaseWarning{ + Case: index, + Message: "Missing feedback field", + }, + } + } + + // Merge params with case-specific params + combinedParams := make(map[string]any) + for k, v := range params { + combinedParams[k] = v + } + if caseParams, ok := caseData["params"].(map[string]any); ok { + for k, v := range caseParams { + combinedParams[k] = v + } + } + + // Try evaluation + defer func() { + if r := recover(); r != nil { + // Catch panic as generic error + caseData["warning"] = &CaseWarning{ + Case: index, + Message: "An exception was raised while executing the evaluation function.", + } + } + }() + + var reqBody map[string]interface{} + err := json.Unmarshal(req.Body, &reqBody) + if err != nil { + return CaseResult{ + Warning: &CaseWarning{ + Case: index, + Message: err.Error(), + }, + } + } + + reqBody["answer"] = caseData["answer"] + reqBody["params"] = combinedParams + + req.Body, err = json.Marshal(reqBody) + if err != nil { + return CaseResult{ + Warning: &CaseWarning{ + Case: index, + Message: err.Error(), + }, + } + } + + resData, err := SendCommand(req, command, h, ctx) + if err != nil { + return CaseResult{ + Warning: &CaseWarning{ + Case: index, + Message: err.Error(), + }, + } + } + + var respBody map[string]any + err = json.Unmarshal(resData, &respBody) + result, ok := respBody["result"].(map[string]interface{}) + if !ok { + log.Error("failed to unmarshal response data", zap.Error(err)) + return CaseResult{ + Warning: &CaseWarning{ + Case: index, + Message: "failed to unmarshal response data", + }, + } + } + + isCorrect, _ := result["is_correct"].(bool) + feedback, _ := result["feedback"].(string) + + return CaseResult{ + IsCorrect: isCorrect, + Feedback: feedback, + } +} + // getCommand tries to extract the command from the request. func (s *RuntimeHandler) getCommand(req Request) (string, bool) { if commandStr := req.Header.Get("command"); commandStr != "" { diff --git a/runtime/handler_test.go b/runtime/handler_test.go new file mode 100644 index 0000000..cb952f7 --- /dev/null +++ b/runtime/handler_test.go @@ -0,0 +1,516 @@ +package runtime_test + +import ( + "context" + "encoding/json" + "errors" + "github.com/lambda-feedback/shimmy/runtime" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" + "net/http" + "testing" +) + +// mockRuntime implements the runtime.Runtime interface. +type mockRuntime struct { + mock.Mock +} + +func (m *mockRuntime) Handle(ctx context.Context, request runtime.EvaluationRequest) (runtime.EvaluationResponse, error) { + args := m.Called(ctx, request) + + // If you want to generate a response based on the request dynamically: + if responseFunc, ok := args.Get(0).(func(runtime.EvaluationRequest) (runtime.EvaluationResponse, error)); ok { + response, err := responseFunc(request) + return response, err + } + + // Otherwise, return the static response + return args.Get(0).(runtime.EvaluationResponse), args.Error(1) +} + +func (m *mockRuntime) Start(ctx context.Context) error { + //Not required for tests + panic("Not required") +} + +func (m *mockRuntime) Shutdown(ctx context.Context) error { + //Not required for tests + panic("Not required") +} + +func setupLogger(t *testing.T) *zap.Logger { + return zaptest.NewLogger(t) +} + +func setupHandlerWithStaticMock(t *testing.T, mockResponse runtime.EvaluationResponse) runtime.Handler { + mockRT := new(mockRuntime) + mockRT.On("Handle", mock.Anything, mock.Anything).Return(mockResponse, nil) + + handler, err := runtime.NewRuntimeHandler(runtime.HandlerParams{ + Runtime: mockRT, + Log: setupLogger(t), + }) + require.NoError(t, err) + + return handler +} + +func mockEvalFunc(req runtime.EvaluationRequest) (runtime.EvaluationResponse, error) { + if req.Data["answer"] == req.Data["response"] { + return runtime.EvaluationResponse{ + "command": "eval", + "result": map[string]interface{}{ + "is_correct": true, + "feedback": "should be 'yes'.", + }, + }, nil + } + return runtime.EvaluationResponse{ + "command": "eval", + "result": map[string]interface{}{ + "is_correct": false, + "feedback": "should be 'hello'.", + }, + }, nil +} + +func setupHandlerWithMockFunc(t *testing.T, mockResponse func(req runtime.EvaluationRequest) (runtime.EvaluationResponse, error)) runtime.Handler { + mockRT := new(mockRuntime) + mockRT.On("Handle", mock.Anything, mock.Anything).Return(mockResponse, nil) + + handler, err := runtime.NewRuntimeHandler(runtime.HandlerParams{ + Runtime: mockRT, + Log: setupLogger(t), + }) + require.NoError(t, err) + + return handler +} + +func createRequestBody(t *testing.T, body map[string]any) []byte { + bodyBytes, err := json.Marshal(body) + require.NoError(t, err) + return bodyBytes +} + +func createRequest(method, path string, body []byte, header http.Header) runtime.Request { + return runtime.Request{ + Method: method, + Path: path, + Body: body, + Header: header, + } +} + +func parseResponseBody(t *testing.T, resp runtime.Response) map[string]any { + require.Equal(t, http.StatusOK, resp.StatusCode) + + var respBody map[string]any + err := json.Unmarshal(resp.Body, &respBody) + require.NoError(t, err) + + return respBody +} + +func TestRuntimeHandler_Handle_Success(t *testing.T) { + mockResponse := runtime.EvaluationResponse{ + "command": "eval", + "result": map[string]interface{}{ + "is_correct": true, + "feedback": "Well done! Your answer is correct.", + }, + } + + handler := setupHandlerWithStaticMock(t, mockResponse) + + body := createRequestBody(t, map[string]any{ + "response": 1, + "answer": 1, + }) + + req := createRequest(http.MethodPost, "/eval", body, http.Header{ + "command": []string{"eval"}, + }) + + resp := handler.Handle(context.Background(), req) + respBody := parseResponseBody(t, resp) + + require.Equal(t, mockResponse["result"], respBody["result"]) +} + +func TestRuntimeHandler_Handle_InvalidCommand(t *testing.T) { + handler, err := runtime.NewRuntimeHandler(runtime.HandlerParams{ + Runtime: &mockRuntime{}, + Log: setupLogger(t), + }) + require.NoError(t, err) + + req := createRequest(http.MethodPost, "/!invalid", []byte(`{}`), http.Header{}) + resp := handler.Handle(context.Background(), req) + + require.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + +func TestRuntimeHandler_Handle_InvalidMethod(t *testing.T) { + handler, err := runtime.NewRuntimeHandler(runtime.HandlerParams{ + Runtime: &mockRuntime{}, + Log: setupLogger(t), + }) + require.NoError(t, err) + + req := createRequest(http.MethodGet, "/eval", []byte(`{}`), http.Header{}) + resp := handler.Handle(context.Background(), req) + + require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) +} + +func TestRuntimeHandler_Handle_Single_Feedback_Case(t *testing.T) { + mockResponse := runtime.EvaluationResponse{ + "command": "eval", + "result": map[string]interface{}{ + "is_correct": true, + }, + } + + handler := setupHandlerWithStaticMock(t, mockResponse) + + body := createRequestBody(t, map[string]any{ + "response": "hello", + "answer": "hello", + "params": map[string]any{ + "cases": []map[string]any{ + {"answer": "other", "feedback": "should be 'hello'."}, + }, + }, + }) + + req := createRequest(http.MethodPost, "/eval", body, http.Header{ + "command": []string{"eval"}, + }) + + resp := handler.Handle(context.Background(), req) + result := parseResponseBody(t, resp)["result"].(map[string]interface{}) + + require.True(t, result["is_correct"].(bool)) + require.NotContains(t, result, "matched_case") + require.NotContains(t, result, "feedback") +} + +func TestRuntimeHandler_Handle_Single_Feedback_Case_Match(t *testing.T) { + handler := setupHandlerWithMockFunc(t, mockEvalFunc) + + body := createRequestBody(t, map[string]any{ + "response": "other", + "answer": "hello", + "params": map[string]any{ + "cases": []map[string]any{ + {"answer": "other", "feedback": "should be 'hello'."}, + }, + }, + }) + + req := createRequest(http.MethodPost, "/eval", body, http.Header{ + "command": []string{"eval"}, + }) + + resp := handler.Handle(context.Background(), req) + result := parseResponseBody(t, resp)["result"].(map[string]interface{}) + + require.False(t, result["is_correct"].(bool)) + require.Equal(t, float64(0), result["matched_case"]) + require.Equal(t, "should be 'hello'.", result["feedback"]) +} + +func TestRunTimeHandler_Warning_Data_Structure(t *testing.T) { + mockResponse := runtime.EvaluationResponse{ + "command": "eval", + "result": map[string]interface{}{ + "is_correct": false, + "feedback": "Missing answer/feedback field", + }, + } + + handler := setupHandlerWithStaticMock(t, mockResponse) + + body := createRequestBody(t, map[string]any{ + "response": "hello", + "answer": "world", + "params": map[string]any{ + "cases": []map[string]any{ + {"feedback": "should be 'hello'."}, + {"answer": "other", "feedback": "should be 'hello'."}, + }, + }, + }) + + req := createRequest(http.MethodPost, "/eval", body, http.Header{ + "command": []string{"eval"}, + }) + + resp := handler.Handle(context.Background(), req) + result := parseResponseBody(t, resp)["result"].(map[string]interface{}) + + require.False(t, result["is_correct"].(bool)) + require.Contains(t, result, "warnings") + + warnings := result["warnings"].([]interface{}) + require.Len(t, warnings, 1) + warningContent := warnings[0].(map[string]interface{}) + require.Equal(t, "Missing answer field", warningContent["message"]) + require.Equal(t, float64(0), warningContent["case"]) +} + +func TestRuntimeHandler_Handle_Multi_Cases_Single_Match(t *testing.T) { + + handler := setupHandlerWithMockFunc(t, mockEvalFunc) + + body := createRequestBody(t, map[string]any{ + "response": "yes", + "answer": "world", + "params": map[string]any{ + "cases": []map[string]any{ + {"answer": "hello", "feedback": "should be 'hello'."}, + {"answer": "yes", "feedback": "should be 'yes'."}, + {"answer": "no", "feedback": "should be 'no'."}, + }, + }, + }) + + req := createRequest(http.MethodPost, "/eval", body, http.Header{ + "command": []string{"eval"}, + }) + + resp := handler.Handle(context.Background(), req) + result := parseResponseBody(t, resp)["result"].(map[string]interface{}) + + require.False(t, result["is_correct"].(bool)) + require.Equal(t, float64(1), result["matched_case"]) + require.Equal(t, "should be 'yes'.", result["feedback"]) +} + +func TestRuntimeHandler_Handle_Multi_Cases_Many_Match(t *testing.T) { + + handler := setupHandlerWithMockFunc(t, mockEvalFunc) + + body := createRequestBody(t, map[string]any{ + "response": "yes", + "answer": "world", + "params": map[string]any{ + "cases": []map[string]any{ + {"answer": "hello", "feedback": "should be 'hello'."}, + {"answer": "yes", "feedback": "should be 'yes'."}, + {"answer": "yes", "feedback": "should be 'not this one'."}, + }, + }, + }) + + req := createRequest(http.MethodPost, "/eval", body, http.Header{ + "command": []string{"eval"}, + }) + + resp := handler.Handle(context.Background(), req) + result := parseResponseBody(t, resp)["result"].(map[string]interface{}) + + require.False(t, result["is_correct"].(bool)) + require.Equal(t, float64(1), result["matched_case"]) + require.Equal(t, "should be 'yes'.", result["feedback"]) +} + +func TestRuntimeHandler_Catch_Exception(t *testing.T) { + + mockResponse := func(req runtime.EvaluationRequest) (runtime.EvaluationResponse, error) { + if params, ok := req.Data["params"].(map[string]interface{}); ok { + if raiseVal, ok := params["raise"].(bool); ok && raiseVal { + return nil, errors.New("catches exception as warning test") + } + } + + return runtime.EvaluationResponse{ + "command": "eval", + "result": map[string]interface{}{ + "is_correct": false, + "feedback": "should be 'hello'.", + }, + }, nil + } + + handler := setupHandlerWithMockFunc(t, mockResponse) + + body := createRequestBody(t, map[string]any{ + "response": "yes", + "answer": "world", + "params": map[string]any{ + "cases": []map[string]any{ + { + "answer": "hello", + "feedback": "should be 'hello'.", + "params": map[string]any{ + "raise": true, + }, + }, + }, + }, + }) + + req := createRequest(http.MethodPost, "/eval", body, http.Header{ + "command": []string{"eval"}, + }) + + resp := handler.Handle(context.Background(), req) + result := parseResponseBody(t, resp)["result"].(map[string]interface{}) + + require.False(t, result["is_correct"].(bool)) + require.Contains(t, result, "warnings") + + warnings := result["warnings"].([]interface{}) + require.Len(t, warnings, 1) + warningContent := warnings[0].(map[string]interface{}) + require.Equal(t, "catches exception as warning test", warningContent["message"]) + require.Equal(t, float64(0), warningContent["case"]) +} + +func TestRuntimeHandler_override_feedback_to_incorrect_case(t *testing.T) { + + handler := setupHandlerWithMockFunc(t, mockEvalFunc) + + body := createRequestBody(t, map[string]any{ + "response": "other", + "answer": "hello", + "params": map[string]any{ + "cases": []map[string]any{ + { + "answer": "other", + "feedback": "should be 'hello'.", + "mark": 1, + }, + }, + }, + }) + + req := createRequest(http.MethodPost, "/eval", body, http.Header{ + "command": []string{"eval"}, + }) + + resp := handler.Handle(context.Background(), req) + result := parseResponseBody(t, resp)["result"].(map[string]interface{}) + + require.True(t, result["is_correct"].(bool)) + require.Equal(t, float64(0), result["matched_case"]) + require.Equal(t, "should be 'hello'.", result["feedback"]) +} + +func TestRunTimeHandler_Healthcheck(t *testing.T) { + mockResponse := runtime.EvaluationResponse{ + "command": "healthcheck", + "result": map[string]interface{}{ + "tests_passed": true, + "successes": []bool{true, false}, + "failures": []bool{true, false}, + "errors": []bool{true, false}, + }, + } + + handler := setupHandlerWithStaticMock(t, mockResponse) + body := createRequestBody(t, map[string]any{}) + + req := createRequest(http.MethodPost, "/healthcheck", body, http.Header{ + "command": []string{"healthcheck"}, + }) + + resp := handler.Handle(context.Background(), req) + result := parseResponseBody(t, resp)["result"].(map[string]interface{}) + + require.Contains(t, result, "tests_passed") + +} + +func TestRunTimeHandler_Valid_Preview(t *testing.T) { + mockResponse := runtime.EvaluationResponse{ + "command": "preview", + "result": map[string]interface{}{ + "preview": map[string]interface{}{ + "latex": "hello", + }, + }, + } + + handler := setupHandlerWithStaticMock(t, mockResponse) + body := createRequestBody(t, map[string]any{ + "response": "hello", + }) + + req := createRequest(http.MethodPost, "/preview", body, http.Header{ + "command": []string{"preview"}, + }) + + resp := handler.Handle(context.Background(), req) + result := parseResponseBody(t, resp)["result"].(map[string]interface{}) + + require.Contains(t, result, "preview") + + preview := result["preview"].(map[string]interface{}) + require.Equal(t, "hello", preview["latex"]) + +} + +func TestRunTimeHandler_Invalid_Preview_No_Body(t *testing.T) { + mockResponse := runtime.EvaluationResponse{ + "command": "preview", + "result": map[string]interface{}{ + "preview": map[string]interface{}{ + "latex": "hello", + }, + }, + } + + handler := setupHandlerWithStaticMock(t, mockResponse) + body := createRequestBody(t, map[string]any{}) + + req := createRequest(http.MethodPost, "/preview", body, http.Header{ + "command": []string{"preview"}, + }) + + resp := handler.Handle(context.Background(), req) + var respBody map[string]any + err := json.Unmarshal(resp.Body, &respBody) + require.NoError(t, err) + + require.Contains(t, respBody, "error") + responseErrors := respBody["error"].(map[string]interface{}) + require.Equal(t, "request validation error", responseErrors["message"]) + +} + +func TestRunTimeHandler_Invalid_Preview_Incorrect_Args(t *testing.T) { + mockResponse := runtime.EvaluationResponse{ + "command": "preview", + "result": map[string]interface{}{ + "preview": map[string]interface{}{ + "latex": "hello", + }, + }, + } + + handler := setupHandlerWithStaticMock(t, mockResponse) + body := createRequestBody(t, map[string]any{ + "response": "hello", + "answer": "world", + }) + + req := createRequest(http.MethodPost, "/preview", body, http.Header{ + "command": []string{"preview"}, + }) + + resp := handler.Handle(context.Background(), req) + var respBody map[string]any + err := json.Unmarshal(resp.Body, &respBody) + require.NoError(t, err) + + require.Contains(t, respBody, "error") + responseErrors := respBody["error"].(map[string]interface{}) + require.Equal(t, "request validation error", responseErrors["message"]) + +} diff --git a/runtime/handler_validate.go b/runtime/handler_validate.go index cfd6080..72f054e 100644 --- a/runtime/handler_validate.go +++ b/runtime/handler_validate.go @@ -56,6 +56,11 @@ func (r *RuntimeHandler) validate(t validationType, command Command, data map[st zap.Stringer("type", t), ) + if t == validationTypeRequest && command == CommandHealth { + // Health does not have a request schema, no need to validate + return nil + } + schema, ok := r.schemas[t] if !ok { log.Error("validation schema not found") @@ -88,6 +93,8 @@ func getSchemaType(command Command) (schema.SchemaType, error) { return schema.SchemaTypeEval, nil case CommandPreview: return schema.SchemaTypePreview, nil + case CommandHealth: + return schema.SchemaTypeHealth, nil default: return 0, errInvalidCommand } diff --git a/runtime/models.go b/runtime/models.go index 2e2dcb3..8e8fa1a 100644 --- a/runtime/models.go +++ b/runtime/models.go @@ -13,6 +13,9 @@ const ( // CommandEvaluate is the command to evaluate the response. CommandEvaluate Command = "eval" + + // CommandHealth is the command for healthcheck + CommandHealth = "healthcheck" ) // ParseCommand parses a command from a given path. @@ -22,6 +25,8 @@ func ParseCommand(path string) (Command, bool) { return CommandEvaluate, true case "preview": return CommandPreview, true + case "healthcheck": + return CommandHealth, true } return "", false diff --git a/runtime/schema/response-health.json b/runtime/schema/response-health.json new file mode 100644 index 0000000..3f58d08 --- /dev/null +++ b/runtime/schema/response-health.json @@ -0,0 +1,71 @@ +{ + "title": "JSON schema for the response from a healthcheck function.", + "description": "This schema is used to check whether a response from a healthcheck function fits the basic structure.", + "properties": { + "command": { + "const": "healthcheck" + }, + "result": { + "type": "object", + "properties": { + "tests_passed": { + "type": "boolean" + }, + "successes": { + "type": "array" + }, + "failures": { + "type": "array" + }, + "errors": { + "type": "array" + } + }, + "additionalProperties": false, + "required": [ + "tests_passed", + "successes", + "failures", + "errors" + ] + }, + "error": { + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "error_thrown": { + "type": [ + "object", + "string" + ] + } + }, + "additionalProperties": true, + "required": [ + "message" + ] + } + }, + "additionalProperties": false, + "allOf": [ + { + "if": { + "required": [ + "result" + ] + }, + "then": { + "required": [ + "command" + ] + }, + "else": { + "required": [ + "error" + ] + } + } + ] +} \ No newline at end of file diff --git a/runtime/schema/schema.go b/runtime/schema/schema.go index d5b6d72..226de7c 100644 --- a/runtime/schema/schema.go +++ b/runtime/schema/schema.go @@ -13,17 +13,19 @@ type SchemaType int const ( SchemaTypeEval SchemaType = iota SchemaTypePreview + SchemaTypeHealth ) type Schema struct { schemas map[SchemaType]*gojsonschema.Schema } -func new(eval *gojsonschema.Schema, preview *gojsonschema.Schema) *Schema { +func new(eval *gojsonschema.Schema, preview *gojsonschema.Schema, health *gojsonschema.Schema) *Schema { return &Schema{ schemas: map[SchemaType]*gojsonschema.Schema{ SchemaTypeEval: eval, SchemaTypePreview: preview, + SchemaTypeHealth: health, }, } } @@ -67,7 +69,7 @@ func NewRequestSchema() (*Schema, error) { return nil, err } - return new(evalSchema, previewSchema), nil + return new(evalSchema, previewSchema, nil), nil } //go:embed response-eval.json @@ -78,6 +80,10 @@ var evalResponseLoader = gojsonschema.NewBytesLoader(evalResponse) var previewResponse json.RawMessage var previewResponseLoader = gojsonschema.NewBytesLoader(previewResponse) +//go:embed response-health.json +var healthResponse json.RawMessage +var healthResponseLoader = gojsonschema.NewBytesLoader(healthResponse) + func NewResponseSchema() (*Schema, error) { evalSchema, err := gojsonschema.NewSchema(evalResponseLoader) if err != nil { @@ -89,5 +95,10 @@ func NewResponseSchema() (*Schema, error) { return nil, err } - return new(evalSchema, previewSchema), nil + healthSchema, err := gojsonschema.NewSchema(healthResponseLoader) + if err != nil { + return nil, err + } + + return new(evalSchema, previewSchema, healthSchema), nil }