From fd0f92513c62759865c28f5502bad7f4cb789c43 Mon Sep 17 00:00:00 2001 From: Brandur Leach Date: Wed, 28 May 2025 08:50:49 -0600 Subject: [PATCH] Allow explicit schema injection to `rivertest.Require*` test functions (#926) Here, resolve #907 by letting an explicit schema be injected into `rivertest.Require*` assertions in a similar way that one can be used in a client. This approach adds a schema in `RequireInsertedOpts`. This comment does a good job of highlight all the potential approaches for adding a schema [1], and unfortunately none of them are all that great. I implemented one other version of this (a variant of option 2 in that list), which as some advantages, but in the end it just ended up ballooning the API out to an uncomfortable degree. The worst part about adding schema to `RequireInsertedOpts` is its interact with the `RequireMany*` functions, where each expectation can set its own schema, and it's not clear what would happen if different expectations set different schemas. I resolved this ambiguity by making it an error to mix and match schemas. Assertions are allowed to send a schema in only the first position like: jobs := requireManyInserted(ctx, bundle.mockT, bundle.driver, []ExpectedJob{ {Args: &Job1Args{String: "foo"}, Opts: bundle.schemaOpts}, {Args: &Job1Args{String: "bar"}}, }) Or send the same schema in all positions: jobs := requireManyInserted(ctx, bundle.mockT, bundle.driver, []ExpectedJob{ {Args: &Job1Args{String: "foo"}, Opts: bundle.schemaOpts}, {Args: &Job1Args{String: "bar"}, Opts: bundle.schemaOpts}, }) But they aren't allowed to set a schema only in position other than the first, or mix and match schemas between expectations. Fixes #907. [1] https://github.com/riverqueue/river/issues/907#issuecomment-2896783495 --- CHANGELOG.md | 1 + client.go | 4 +- client_test.go | 80 ++++++++++++++++++ driver_client_test.go | 27 ++++++ internal/dblist/db_list.go | 35 +++++--- internal/dblist/db_list_test.go | 68 ++++++++++++--- job_list_params.go | 143 +++++++++++++++++++++----------- 7 files changed, 283 insertions(+), 75 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e2b72b6..e81ebe28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - CLI `river migrate-get` now takes a `--schema` option to inject a custom schema into dumped migrations and schema comments are hidden if `--schema` option isn't provided. [PR #903](https://github.com/riverqueue/river/pull/903). - Added `riverlog.NewMiddlewareCustomContext` that makes the use of `riverlog` job-persisted logging possible with non-slog loggers. [PR #919](https://github.com/riverqueue/river/pull/919). - Added `RequireInsertedOpts.Schema`, allowing an explicit schema to be set when asserting on job inserts with `rivertest`. [PR #926](https://github.com/riverqueue/river/pull/926). +- Added `JobListParams.Where`, which provides an escape hatch for job listing that runs arbitrary SQL with named parameters. [PR #933](https://github.com/riverqueue/river/pull/933). ### Changed diff --git a/client.go b/client.go index 2277068d..556982f0 100644 --- a/client.go +++ b/client.go @@ -2017,7 +2017,7 @@ func (c *Client[TTx]) JobList(ctx context.Context, params *JobListParams) (*JobL } params.schema = c.config.Schema - if c.driver.DatabaseName() == databaseNameSQLite && params.metadataFragment != "" { + if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled { return nil, errJobListParamsMetadataNotSupportedSQLite } @@ -2052,7 +2052,7 @@ func (c *Client[TTx]) JobListTx(ctx context.Context, tx TTx, params *JobListPara } params.schema = c.config.Schema - if c.driver.DatabaseName() == databaseNameSQLite && params.metadataFragment != "" { + if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled { return nil, errJobListParamsMetadataNotSupportedSQLite } diff --git a/client_test.go b/client_test.go index 698dfa0a..7437e3a4 100644 --- a/client_test.go +++ b/client_test.go @@ -3776,6 +3776,86 @@ func Test_Client_JobList(t *testing.T) { require.Equal(t, []int64{job3.ID, job2.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID })) }) + t.Run("ArbitraryWhereRawSQL", func(t *testing.T) { + t.Parallel() + + client, bundle := setup(t) + + var ( + job1 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"foo": "bar"}`), Schema: bundle.schema}) + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema}) + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema}) + ) + + listRes, err := client.JobList(ctx, NewJobListParams().Where(`jsonb_path_query_first(metadata, '$.foo') = '"bar"'::jsonb`)) + require.NoError(t, err) + require.Equal(t, []int64{job1.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID })) + }) + + t.Run("ArbitraryWhereNamedParams", func(t *testing.T) { + t.Parallel() + + client, bundle := setup(t) + + var ( + job1 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"foo": "bar"}`), Schema: bundle.schema}) + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema}) + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema}) + ) + + listRes, err := client.JobList(ctx, NewJobListParams().Where("jsonb_path_query_first(metadata, @json_query) = @json_val", NamedArgs{ + "json_query": "$.foo", + "json_val": `"bar"`, + })) + require.NoError(t, err) + require.Equal(t, []int64{job1.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID })) + }) + + t.Run("ArbitraryWhereMultipleNamedParams", func(t *testing.T) { + t.Parallel() + + client, bundle := setup(t) + + var ( + job1 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema}) + job2 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema}) + job3 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema}) + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema}) + ) + + listRes, err := client.JobList(ctx, NewJobListParams().Where("id IN (@id1, @id2, @id3)", + NamedArgs{"id1": job1.ID}, + NamedArgs{"id2": job2.ID}, + NamedArgs{"id3": job3.ID}, + )) + require.NoError(t, err) + require.Equal(t, []int64{job1.ID, job2.ID, job3.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID })) + }) + + t.Run("ArbitraryWhereMultipleClauses", func(t *testing.T) { + t.Parallel() + + client, bundle := setup(t) + + var ( + job = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{ + MaxAttempts: ptrutil.Ptr(27), + Queue: ptrutil.Ptr("custom_queue"), + Schema: bundle.schema, + State: ptrutil.Ptr(rivertype.JobStateDiscarded), + }) + _ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema}) + ) + + listRes, err := client.JobList(ctx, NewJobListParams(). + Where("kind = @kind", NamedArgs{"kind": job.Kind}). + Where("max_attempts = @max_attempts", NamedArgs{"max_attempts": job.MaxAttempts}). + Where("queue = @queue", NamedArgs{"queue": job.Queue}), + ) + require.NoError(t, err) + require.Equal(t, []int64{job.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID })) + }) + t.Run("WithCancelledContext", func(t *testing.T) { t.Parallel() diff --git a/driver_client_test.go b/driver_client_test.go index 248b95d9..eb5e48c8 100644 --- a/driver_client_test.go +++ b/driver_client_test.go @@ -332,6 +332,33 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T, require.Equal(t, job.ID, listRes.Jobs[0].ID) }) + t.Run("JobListTxWhere", func(t *testing.T) { + t.Parallel() + + client, bundle := setup(t) + + tx, execTx := beginTx(ctx, t, bundle) + + job := testfactory.Job(ctx, t, execTx, &testfactory.JobOpts{ + Metadata: []byte(`{"foo":"bar","bar":"baz"}`), + Schema: bundle.schema, + }) + + listParams := NewJobListParams() + + if client.driver.DatabaseName() == databaseNameSQLite { + listParams = listParams.Where("metadata ->> @json_path = @json_val", NamedArgs{"json_path": "$.foo", "json_val": "bar"}) + } else { + // "bar" is quoted in this branch because `jsonb_path_query_first` needs to be compared to a JSON value + listParams = listParams.Where("jsonb_path_query_first(metadata, @json_path) = @json_val", NamedArgs{"json_path": "$.foo", "json_val": `"bar"`}) + } + + listRes, err := client.JobListTx(ctx, tx, listParams) + require.NoError(t, err) + require.Len(t, listRes.Jobs, 1) + require.Equal(t, job.ID, listRes.Jobs[0].ID) + }) + t.Run("QueueGet", func(t *testing.T) { t.Parallel() diff --git a/internal/dblist/db_list.go b/internal/dblist/db_list.go index 48715f51..b221a54d 100644 --- a/internal/dblist/db_list.go +++ b/internal/dblist/db_list.go @@ -25,20 +25,27 @@ type JobListOrderBy struct { } type JobListParams struct { - Conditions string IDs []int64 Kinds []string LimitCount int32 - NamedArgs map[string]any OrderBy []JobListOrderBy Priorities []int16 Queues []string Schema string States []rivertype.JobState + Where []WherePredicate +} + +type WherePredicate struct { + NamedArgs map[string]any + SQL string } func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListParams, sqlFragmentColumnIn func(column string, values any) (string, any, error)) ([]*rivertype.JobRow, error) { - var whereBuilder strings.Builder + var ( + namedArgs = make(map[string]any) + whereBuilder strings.Builder + ) orderBy := make([]JobListOrderBy, len(params.OrderBy)) for i, o := range params.OrderBy { @@ -48,11 +55,6 @@ func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListPara } } - namedArgs := params.NamedArgs - if namedArgs == nil { - namedArgs = make(map[string]any) - } - // Writes an `AND` to connect SQL predicates as long as this isn't the first // predicate. writeAndAfterFirst := func() { @@ -122,9 +124,22 @@ func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListPara namedArgs[column] = arg } - if params.Conditions != "" { + for _, where := range params.Where { writeAndAfterFirst() - whereBuilder.WriteString(params.Conditions) + + whereBuilder.WriteString(where.SQL) + for name, val := range where.NamedArgs { + expectedSymbol := "@" + name + if !strings.Contains(where.SQL, expectedSymbol) { + return nil, fmt.Errorf("expected %q to contain named arg symbol %s", where.SQL, expectedSymbol) + } + + if _, ok := namedArgs[name]; ok { + return nil, fmt.Errorf("named argument %s already registered", expectedSymbol) + } + + namedArgs[name] = val + } } // A condition of some kind is needed, so given no others write one that'll diff --git a/internal/dblist/db_list_test.go b/internal/dblist/db_list_test.go index 02b3a418..a4e4d56d 100644 --- a/internal/dblist/db_list_test.go +++ b/internal/dblist/db_list_test.go @@ -55,11 +55,12 @@ func TestJobListNoJobs(t *testing.T) { bundle := setup() _, err := JobList(ctx, bundle.exec, &JobListParams{ - Conditions: "queue = 'test' AND priority = 1 AND args->>'foo' = @foo", - NamedArgs: pgx.NamedArgs{"foo": "bar"}, States: []rivertype.JobState{rivertype.JobStateCompleted}, LimitCount: 1, OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderAsc}}, + Where: []WherePredicate{ + {NamedArgs: map[string]any{"foo": "bar"}, SQL: "queue = 'test' AND priority = 1 AND args->>'foo' = @foo"}, + }, }, bundle.driver.SQLFragmentColumnIn) require.NoError(t, err) }) @@ -148,11 +149,12 @@ func TestJobListWithJobs(t *testing.T) { bundle := setup(t) params := &JobListParams{ - Conditions: "jsonb_extract_path(args, VARIADIC @paths1::text[]) = @value1::jsonb", LimitCount: 2, - NamedArgs: map[string]any{"paths1": []string{"job_num"}, "value1": 2}, OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}}, States: []rivertype.JobState{rivertype.JobStateAvailable}, + Where: []WherePredicate{ + {NamedArgs: map[string]any{"paths1": []string{"job_num"}, "value1": 2}, SQL: "jsonb_extract_path(args, VARIADIC @paths1::text[]) = @value1::jsonb"}, + }, } execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) { @@ -164,7 +166,7 @@ func TestJobListWithJobs(t *testing.T) { }) }) - t.Run("ConditionsWithIDs", func(t *testing.T) { + t.Run("WhereWithIDs", func(t *testing.T) { t.Parallel() bundle := setup(t) job1, job2, job3 := bundle.jobs[0], bundle.jobs[1], bundle.jobs[2] @@ -188,7 +190,7 @@ func TestJobListWithJobs(t *testing.T) { }) }) - t.Run("ConditionsWithIDsAndPriorities", func(t *testing.T) { + t.Run("WhereWithIDsAndPriorities", func(t *testing.T) { t.Parallel() bundle := setup(t) job1, job2, job3 := bundle.jobs[0], bundle.jobs[1], bundle.jobs[2] @@ -207,17 +209,19 @@ func TestJobListWithJobs(t *testing.T) { }) }) - t.Run("ConditionsWithKinds", func(t *testing.T) { + t.Run("WhereWithKinds", func(t *testing.T) { t.Parallel() bundle := setup(t) params := &JobListParams{ - Conditions: "finalized_at IS NULL", LimitCount: 2, OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}}, Kinds: []string{"alternate_kind"}, States: []rivertype.JobState{rivertype.JobStateAvailable}, + Where: []WherePredicate{ + {SQL: "finalized_at IS NULL"}, + }, } execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) { @@ -229,7 +233,7 @@ func TestJobListWithJobs(t *testing.T) { }) }) - t.Run("ConditionsWithPriorities", func(t *testing.T) { + t.Run("WhereWithPriorities", func(t *testing.T) { t.Parallel() bundle := setup(t) _, job2, job3, _, job5 := bundle.jobs[0], bundle.jobs[1], bundle.jobs[2], bundle.jobs[3], bundle.jobs[4] @@ -246,17 +250,19 @@ func TestJobListWithJobs(t *testing.T) { }) }) - t.Run("ConditionsWithQueues", func(t *testing.T) { + t.Run("WhereWithQueues", func(t *testing.T) { t.Parallel() bundle := setup(t) params := &JobListParams{ - Conditions: "finalized_at IS NULL", LimitCount: 2, OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}}, Queues: []string{"priority"}, States: []rivertype.JobState{rivertype.JobStateAvailable}, + Where: []WherePredicate{ + {SQL: "finalized_at IS NULL"}, + }, } execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) { @@ -274,10 +280,11 @@ func TestJobListWithJobs(t *testing.T) { bundle := setup(t) params := &JobListParams{ - Conditions: "metadata @> @metadata_filter::jsonb", LimitCount: 2, - NamedArgs: map[string]any{"metadata_filter": `{"some_key": "some_value"}`}, OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}}, + Where: []WherePredicate{ + {NamedArgs: map[string]any{"metadata_filter": `{"some_key": "some_value"}`}, SQL: "metadata @> @metadata_filter::jsonb"}, + }, } execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) { @@ -288,4 +295,39 @@ func TestJobListWithJobs(t *testing.T) { require.Equal(t, []int64{job3.ID}, returnedIDs) }) }) + + t.Run("NamedArgNotPresentInQueryError", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + + params := &JobListParams{ + LimitCount: 2, + OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}}, + Where: []WherePredicate{ + {NamedArgs: map[string]any{"not_present": "foo"}, SQL: "1"}, + }, + } + + _, err := JobList(ctx, bundle.exec, params, bundle.driver.SQLFragmentColumnIn) + require.EqualError(t, err, `expected "1" to contain named arg symbol @not_present`) + }) + + t.Run("DuplicateNamedArgError", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + + params := &JobListParams{ + LimitCount: 2, + OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}}, + Where: []WherePredicate{ + {NamedArgs: map[string]any{"duplicate": "foo"}, SQL: "duplicate = @duplicate"}, + {NamedArgs: map[string]any{"duplicate": "foo"}, SQL: "duplicate = @duplicate"}, + }, + } + + _, err := JobList(ctx, bundle.exec, params, bundle.driver.SQLFragmentColumnIn) + require.EqualError(t, err, "named argument @duplicate already registered") + }) } diff --git a/job_list_params.go b/job_list_params.go index d8836d88..2528f57e 100644 --- a/job_list_params.go +++ b/job_list_params.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" "fmt" - "strings" + "maps" "time" "github.com/riverqueue/river/internal/dblist" @@ -165,18 +165,19 @@ const ( // // params := NewJobListParams().OrderBy(JobListOrderByTime, SortOrderAsc).First(100) type JobListParams struct { - after *JobListCursor - ids []int64 - kinds []string - metadataFragment string - overrodeState bool - paginationCount int32 - priorities []int16 - queues []string - schema string - sortField JobListOrderByField - sortOrder SortOrder - states []rivertype.JobState + after *JobListCursor + ids []int64 + kinds []string + metadataCalled bool + overrodeState bool + paginationCount int32 + priorities []int16 + queues []string + schema string + sortField JobListOrderByField + sortOrder SortOrder + states []rivertype.JobState + where []dblist.WherePredicate } // NewJobListParams creates a new JobListParams to return available jobs sorted @@ -201,25 +202,23 @@ func NewJobListParams() *JobListParams { func (p *JobListParams) copy() *JobListParams { return &JobListParams{ - after: p.after, - ids: append([]int64(nil), p.ids...), - kinds: append([]string(nil), p.kinds...), - metadataFragment: p.metadataFragment, - overrodeState: p.overrodeState, - paginationCount: p.paginationCount, - priorities: append([]int16(nil), p.priorities...), - queues: append([]string(nil), p.queues...), - sortField: p.sortField, - sortOrder: p.sortOrder, - schema: p.schema, - states: append([]rivertype.JobState(nil), p.states...), + after: p.after, + ids: append([]int64(nil), p.ids...), + kinds: append([]string(nil), p.kinds...), + metadataCalled: p.metadataCalled, + overrodeState: p.overrodeState, + paginationCount: p.paginationCount, + priorities: append([]int16(nil), p.priorities...), + queues: append([]string(nil), p.queues...), + sortField: p.sortField, + sortOrder: p.sortOrder, + schema: p.schema, + states: append([]rivertype.JobState(nil), p.states...), + where: append([]dblist.WherePredicate(nil), p.where...), } } func (p *JobListParams) toDBParams() (*dblist.JobListParams, error) { - conditionsBuilder := &strings.Builder{} - conditions := make([]string, 0, 10) - namedArgs := make(map[string]any) orderBy := make([]dblist.JobListOrderBy, 0, 2) var sortOrder dblist.SortOrder @@ -265,47 +264,34 @@ func (p *JobListParams) toDBParams() (*dblist.JobListParams, error) { orderBy = append(orderBy, dblist.JobListOrderBy{Expr: "id", Order: sortOrder}) - if p.metadataFragment != "" { - conditions = append(conditions, `metadata @> @metadata_fragment::jsonb`) - namedArgs["metadata_fragment"] = p.metadataFragment - } - if p.after != nil { + namedArgs := map[string]any{"after_id": p.after.id} if p.after.time.IsZero() { // order by ID only if sortOrder == dblist.SortOrderAsc { - conditions = append(conditions, "(id > @after_id)") + p.where = append(p.where, dblist.WherePredicate{NamedArgs: namedArgs, SQL: "(id > @after_id)"}) } else { - conditions = append(conditions, "(id < @after_id)") + p.where = append(p.where, dblist.WherePredicate{NamedArgs: namedArgs, SQL: "(id < @after_id)"}) } } else { + namedArgs["cursor_time"] = p.after.time if sortOrder == dblist.SortOrderAsc { - conditions = append(conditions, fmt.Sprintf(`("%s" > @cursor_time OR ("%s" = @cursor_time AND "id" > @after_id))`, timeField, timeField)) + p.where = append(p.where, dblist.WherePredicate{NamedArgs: namedArgs, SQL: fmt.Sprintf(`("%s" > @cursor_time OR ("%s" = @cursor_time AND "id" > @after_id))`, timeField, timeField)}) } else { - conditions = append(conditions, fmt.Sprintf(`("%s" < @cursor_time OR ("%s" = @cursor_time AND "id" < @after_id))`, timeField, timeField)) + p.where = append(p.where, dblist.WherePredicate{NamedArgs: namedArgs, SQL: fmt.Sprintf(`("%s" < @cursor_time OR ("%s" = @cursor_time AND "id" < @after_id))`, timeField, timeField)}) } - namedArgs["cursor_time"] = p.after.time - } - namedArgs["after_id"] = p.after.id - } - - for i, condition := range conditions { - if i > 0 { - conditionsBuilder.WriteString("\n AND ") } - conditionsBuilder.WriteString(condition) } return &dblist.JobListParams{ - Conditions: conditionsBuilder.String(), IDs: p.ids, Kinds: p.kinds, LimitCount: p.paginationCount, - NamedArgs: namedArgs, OrderBy: orderBy, Priorities: p.priorities, Queues: p.queues, Schema: p.schema, States: p.states, + Where: p.where, }, nil } @@ -362,10 +348,16 @@ func (p *JobListParams) Kinds(kinds ...string) *JobListParams { // https://www.postgresql.org/docs/current/functions-json.html // // This function isn't supported in SQLite due to SQLite not having an -// equivalent operator to use, so there's no efficient way to implement it. +// equivalent operator to use, so there's no efficient way to implement it. We +// recommend the use of Where using a condition with a comparison on the `->>` +// operator instead. func (p *JobListParams) Metadata(json string) *JobListParams { paramsCopy := p.copy() - paramsCopy.metadataFragment = json + paramsCopy.metadataCalled = true + paramsCopy.where = append(paramsCopy.where, dblist.WherePredicate{ + NamedArgs: map[string]any{"metadata_fragment": json}, + SQL: `metadata @> @metadata_fragment::jsonb`, + }) return paramsCopy } @@ -424,6 +416,57 @@ func (p *JobListParams) States(states ...rivertype.JobState) *JobListParams { return paramsCopy } +// NamedArgs are named arguments for use with JobListParams.Where. Keys should +// look like "my_param", and map to parameters like "@my_param" in SQL queries. +// "@" are present in the SQL, but not in the keys of this map. +type NamedArgs map[string]any + +// Where is an all-encompassing query escape hatch that adds an arbitrary +// predicate after a list query's `WHERE ...` clause. Use of other JobListParams +// filters should be preferred where possible because they're safer and their +// compatibility between drivers is better guaranteed, but in case none is +// suitable, Where can be used as a last resort. +// +// For example, using Where to query with `jsonb_path_query_first(...)` using a +// JSON path, a function that's specific to Postgres: +// +// listParams = listParams.Where("jsonb_path_query_first(metadata, @json_path) = @json_val", NamedArgs{"json_path": "$.foo", "json_val": `"bar"`}) +// +// A JSON path can be used in a query in SQLite as well, but there the `->` or +// `->>` operators must be used instead: +// +// listParams = listParams.Where("metadata ->> @json_path = @json_val", NamedArgs{"json_path": "$.foo", "json_val": "bar"}) +// +// Arguments beyond the first are interpreted as named parameters. Each one +// should be present in the query SQL prefixed with a `@` symbol. Multiple sets +// of named parameters will be merged together, with values in later sets +// overwriting those in earlier ones. +// +// Calling Where multiple times will add multiple conditions separate by `AND`. +// Use `OR` instead by stuffing all conditions into a single Where invocation. +// +// Consider use of this function possibly hazardous! Any time raw SQL is in +// play, an application is opening itself up to SQL injection attacks. Never mix +// unsanitized user input into a SQL string, and use named parameters to curb +// the likelihood of injection. +func (p *JobListParams) Where(sql string, namedArgsMany ...NamedArgs) *JobListParams { + paramsCopy := p.copy() + + var allNamedArgs NamedArgs + if len(namedArgsMany) > 0 { + for i, namedArgs := range namedArgsMany { + if i == 0 { + allNamedArgs = namedArgs + } else { + maps.Copy(allNamedArgs, namedArgs) + } + } + } + + paramsCopy.where = append(paramsCopy.where, dblist.WherePredicate{NamedArgs: allNamedArgs, SQL: sql}) + return paramsCopy +} + func jobListTimeFieldForState(state rivertype.JobState) string { // Don't include a `default` so `exhaustive` lint can detect omissions. switch state {