Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- CLI `river migrate-get` now takes a `--schema` option to inject a custom schema into dumped migrations and schema comments are hidden if `--schema` option isn't provided. [PR #903](https://github.com/riverqueue/river/pull/903).
- Added `riverlog.NewMiddlewareCustomContext` that makes the use of `riverlog` job-persisted logging possible with non-slog loggers. [PR #919](https://github.com/riverqueue/river/pull/919).
- Added `RequireInsertedOpts.Schema`, allowing an explicit schema to be set when asserting on job inserts with `rivertest`. [PR #926](https://github.com/riverqueue/river/pull/926).
- Added `JobListParams.Where`, which provides an escape hatch for job listing that runs arbitrary SQL with named parameters. [PR #933](https://github.com/riverqueue/river/pull/933).

### Changed

Expand Down
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2017,7 +2017,7 @@ func (c *Client[TTx]) JobList(ctx context.Context, params *JobListParams) (*JobL
}
params.schema = c.config.Schema

if c.driver.DatabaseName() == databaseNameSQLite && params.metadataFragment != "" {
if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled {
return nil, errJobListParamsMetadataNotSupportedSQLite
}

Expand Down Expand Up @@ -2052,7 +2052,7 @@ func (c *Client[TTx]) JobListTx(ctx context.Context, tx TTx, params *JobListPara
}
params.schema = c.config.Schema

if c.driver.DatabaseName() == databaseNameSQLite && params.metadataFragment != "" {
if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled {
return nil, errJobListParamsMetadataNotSupportedSQLite
}

Expand Down
80 changes: 80 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3776,6 +3776,86 @@ func Test_Client_JobList(t *testing.T) {
require.Equal(t, []int64{job3.ID, job2.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
})

t.Run("ArbitraryWhereRawSQL", func(t *testing.T) {
t.Parallel()

client, bundle := setup(t)

var (
job1 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"foo": "bar"}`), Schema: bundle.schema})
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema})
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema})
)

listRes, err := client.JobList(ctx, NewJobListParams().Where(`jsonb_path_query_first(metadata, '$.foo') = '"bar"'::jsonb`))
require.NoError(t, err)
require.Equal(t, []int64{job1.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
})

t.Run("ArbitraryWhereNamedParams", func(t *testing.T) {
t.Parallel()

client, bundle := setup(t)

var (
job1 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"foo": "bar"}`), Schema: bundle.schema})
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema})
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema})
)

listRes, err := client.JobList(ctx, NewJobListParams().Where("jsonb_path_query_first(metadata, @json_query) = @json_val", NamedArgs{
"json_query": "$.foo",
"json_val": `"bar"`,
}))
require.NoError(t, err)
require.Equal(t, []int64{job1.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
})

t.Run("ArbitraryWhereMultipleNamedParams", func(t *testing.T) {
t.Parallel()

client, bundle := setup(t)

var (
job1 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
job2 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
job3 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
)

listRes, err := client.JobList(ctx, NewJobListParams().Where("id IN (@id1, @id2, @id3)",
NamedArgs{"id1": job1.ID},
NamedArgs{"id2": job2.ID},
NamedArgs{"id3": job3.ID},
))
require.NoError(t, err)
require.Equal(t, []int64{job1.ID, job2.ID, job3.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
})

t.Run("ArbitraryWhereMultipleClauses", func(t *testing.T) {
t.Parallel()

client, bundle := setup(t)

var (
job = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{
MaxAttempts: ptrutil.Ptr(27),
Queue: ptrutil.Ptr("custom_queue"),
Schema: bundle.schema,
State: ptrutil.Ptr(rivertype.JobStateDiscarded),
})
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
)

listRes, err := client.JobList(ctx, NewJobListParams().
Where("kind = @kind", NamedArgs{"kind": job.Kind}).
Where("max_attempts = @max_attempts", NamedArgs{"max_attempts": job.MaxAttempts}).
Where("queue = @queue", NamedArgs{"queue": job.Queue}),
)
require.NoError(t, err)
require.Equal(t, []int64{job.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
})

t.Run("WithCancelledContext", func(t *testing.T) {
t.Parallel()

Expand Down
27 changes: 27 additions & 0 deletions driver_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,33 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T,
require.Equal(t, job.ID, listRes.Jobs[0].ID)
})

t.Run("JobListTxWhere", func(t *testing.T) {
t.Parallel()

client, bundle := setup(t)

tx, execTx := beginTx(ctx, t, bundle)

job := testfactory.Job(ctx, t, execTx, &testfactory.JobOpts{
Metadata: []byte(`{"foo":"bar","bar":"baz"}`),
Schema: bundle.schema,
})

listParams := NewJobListParams()

if client.driver.DatabaseName() == databaseNameSQLite {
listParams = listParams.Where("metadata ->> @json_path = @json_val", NamedArgs{"json_path": "$.foo", "json_val": "bar"})
} else {
// "bar" is quoted in this branch because `jsonb_path_query_first` needs to be compared to a JSON value
listParams = listParams.Where("jsonb_path_query_first(metadata, @json_path) = @json_val", NamedArgs{"json_path": "$.foo", "json_val": `"bar"`})
}

listRes, err := client.JobListTx(ctx, tx, listParams)
require.NoError(t, err)
require.Len(t, listRes.Jobs, 1)
require.Equal(t, job.ID, listRes.Jobs[0].ID)
})

t.Run("QueueGet", func(t *testing.T) {
t.Parallel()

Expand Down
35 changes: 25 additions & 10 deletions internal/dblist/db_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,27 @@ type JobListOrderBy struct {
}

type JobListParams struct {
Conditions string
IDs []int64
Kinds []string
LimitCount int32
NamedArgs map[string]any
OrderBy []JobListOrderBy
Priorities []int16
Queues []string
Schema string
States []rivertype.JobState
Where []WherePredicate
}

type WherePredicate struct {
NamedArgs map[string]any
SQL string
}

func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListParams, sqlFragmentColumnIn func(column string, values any) (string, any, error)) ([]*rivertype.JobRow, error) {
var whereBuilder strings.Builder
var (
namedArgs = make(map[string]any)
whereBuilder strings.Builder
)

orderBy := make([]JobListOrderBy, len(params.OrderBy))
for i, o := range params.OrderBy {
Expand All @@ -48,11 +55,6 @@ func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListPara
}
}

namedArgs := params.NamedArgs
if namedArgs == nil {
namedArgs = make(map[string]any)
}

// Writes an `AND` to connect SQL predicates as long as this isn't the first
// predicate.
writeAndAfterFirst := func() {
Expand Down Expand Up @@ -122,9 +124,22 @@ func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListPara
namedArgs[column] = arg
}

if params.Conditions != "" {
for _, where := range params.Where {
writeAndAfterFirst()
whereBuilder.WriteString(params.Conditions)

whereBuilder.WriteString(where.SQL)
for name, val := range where.NamedArgs {
expectedSymbol := "@" + name
if !strings.Contains(where.SQL, expectedSymbol) {
return nil, fmt.Errorf("expected %q to contain named arg symbol %s", where.SQL, expectedSymbol)
}

if _, ok := namedArgs[name]; ok {
return nil, fmt.Errorf("named argument %s already registered", expectedSymbol)
}

namedArgs[name] = val
}
}

// A condition of some kind is needed, so given no others write one that'll
Expand Down
68 changes: 55 additions & 13 deletions internal/dblist/db_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ func TestJobListNoJobs(t *testing.T) {
bundle := setup()

_, err := JobList(ctx, bundle.exec, &JobListParams{
Conditions: "queue = 'test' AND priority = 1 AND args->>'foo' = @foo",
NamedArgs: pgx.NamedArgs{"foo": "bar"},
States: []rivertype.JobState{rivertype.JobStateCompleted},
LimitCount: 1,
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderAsc}},
Where: []WherePredicate{
{NamedArgs: map[string]any{"foo": "bar"}, SQL: "queue = 'test' AND priority = 1 AND args->>'foo' = @foo"},
},
}, bundle.driver.SQLFragmentColumnIn)
require.NoError(t, err)
})
Expand Down Expand Up @@ -148,11 +149,12 @@ func TestJobListWithJobs(t *testing.T) {
bundle := setup(t)

params := &JobListParams{
Conditions: "jsonb_extract_path(args, VARIADIC @paths1::text[]) = @value1::jsonb",
LimitCount: 2,
NamedArgs: map[string]any{"paths1": []string{"job_num"}, "value1": 2},
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
States: []rivertype.JobState{rivertype.JobStateAvailable},
Where: []WherePredicate{
{NamedArgs: map[string]any{"paths1": []string{"job_num"}, "value1": 2}, SQL: "jsonb_extract_path(args, VARIADIC @paths1::text[]) = @value1::jsonb"},
},
}

execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) {
Expand All @@ -164,7 +166,7 @@ func TestJobListWithJobs(t *testing.T) {
})
})

t.Run("ConditionsWithIDs", func(t *testing.T) {
t.Run("WhereWithIDs", func(t *testing.T) {
t.Parallel()
bundle := setup(t)
job1, job2, job3 := bundle.jobs[0], bundle.jobs[1], bundle.jobs[2]
Expand All @@ -188,7 +190,7 @@ func TestJobListWithJobs(t *testing.T) {
})
})

t.Run("ConditionsWithIDsAndPriorities", func(t *testing.T) {
t.Run("WhereWithIDsAndPriorities", func(t *testing.T) {
t.Parallel()
bundle := setup(t)
job1, job2, job3 := bundle.jobs[0], bundle.jobs[1], bundle.jobs[2]
Expand All @@ -207,17 +209,19 @@ func TestJobListWithJobs(t *testing.T) {
})
})

t.Run("ConditionsWithKinds", func(t *testing.T) {
t.Run("WhereWithKinds", func(t *testing.T) {
t.Parallel()

bundle := setup(t)

params := &JobListParams{
Conditions: "finalized_at IS NULL",
LimitCount: 2,
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
Kinds: []string{"alternate_kind"},
States: []rivertype.JobState{rivertype.JobStateAvailable},
Where: []WherePredicate{
{SQL: "finalized_at IS NULL"},
},
}

execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) {
Expand All @@ -229,7 +233,7 @@ func TestJobListWithJobs(t *testing.T) {
})
})

t.Run("ConditionsWithPriorities", func(t *testing.T) {
t.Run("WhereWithPriorities", func(t *testing.T) {
t.Parallel()
bundle := setup(t)
_, job2, job3, _, job5 := bundle.jobs[0], bundle.jobs[1], bundle.jobs[2], bundle.jobs[3], bundle.jobs[4]
Expand All @@ -246,17 +250,19 @@ func TestJobListWithJobs(t *testing.T) {
})
})

t.Run("ConditionsWithQueues", func(t *testing.T) {
t.Run("WhereWithQueues", func(t *testing.T) {
t.Parallel()

bundle := setup(t)

params := &JobListParams{
Conditions: "finalized_at IS NULL",
LimitCount: 2,
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
Queues: []string{"priority"},
States: []rivertype.JobState{rivertype.JobStateAvailable},
Where: []WherePredicate{
{SQL: "finalized_at IS NULL"},
},
}

execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) {
Expand All @@ -274,10 +280,11 @@ func TestJobListWithJobs(t *testing.T) {
bundle := setup(t)

params := &JobListParams{
Conditions: "metadata @> @metadata_filter::jsonb",
LimitCount: 2,
NamedArgs: map[string]any{"metadata_filter": `{"some_key": "some_value"}`},
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
Where: []WherePredicate{
{NamedArgs: map[string]any{"metadata_filter": `{"some_key": "some_value"}`}, SQL: "metadata @> @metadata_filter::jsonb"},
},
}

execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) {
Expand All @@ -288,4 +295,39 @@ func TestJobListWithJobs(t *testing.T) {
require.Equal(t, []int64{job3.ID}, returnedIDs)
})
})

t.Run("NamedArgNotPresentInQueryError", func(t *testing.T) {
t.Parallel()

bundle := setup(t)

params := &JobListParams{
LimitCount: 2,
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
Where: []WherePredicate{
{NamedArgs: map[string]any{"not_present": "foo"}, SQL: "1"},
},
}

_, err := JobList(ctx, bundle.exec, params, bundle.driver.SQLFragmentColumnIn)
require.EqualError(t, err, `expected "1" to contain named arg symbol @not_present`)
})

t.Run("DuplicateNamedArgError", func(t *testing.T) {
t.Parallel()

bundle := setup(t)

params := &JobListParams{
LimitCount: 2,
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
Where: []WherePredicate{
{NamedArgs: map[string]any{"duplicate": "foo"}, SQL: "duplicate = @duplicate"},
{NamedArgs: map[string]any{"duplicate": "foo"}, SQL: "duplicate = @duplicate"},
},
}

_, err := JobList(ctx, bundle.exec, params, bundle.driver.SQLFragmentColumnIn)
require.EqualError(t, err, "named argument @duplicate already registered")
})
}
Loading
Loading