diff --git a/CHANGELOG.md b/CHANGELOG.md index 213f2fd9..33031e34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - CLI `river migrate-get` now takes a `--schema` option to inject a custom schema into dumped migrations and schema comments are hidden if `--schema` option isn't provided. [PR #903](https://github.com/riverqueue/river/pull/903). - Added `riverlog.NewMiddlewareCustomContext` that makes the use of `riverlog` job-persisted logging possible with non-slog loggers. [PR #919](https://github.com/riverqueue/river/pull/919). - Added `RequireInsertedOpts.Schema`, allowing an explicit schema to be set when asserting on job inserts with `rivertest`. [PR #926](https://github.com/riverqueue/river/pull/926). +- When using a driver that doesn't support listen/notify, producers within same process are notified immediately of new job inserts and queue changes (e.g. pause/resume) without having to poll when non-transactional variants are used (i.e. `Insert` instead of `InsertTx`). [PR #928](https://github.com/riverqueue/river/pull/928). - Added `JobListParams.Where`, which provides an escape hatch for job listing that runs arbitrary SQL with named parameters. [PR #933](https://github.com/riverqueue/river/pull/933). ### Changed diff --git a/client.go b/client.go index 25b4d4f1..8a432444 100644 --- a/client.go +++ b/client.go @@ -23,6 +23,7 @@ import ( "github.com/riverqueue/river/internal/notifier" "github.com/riverqueue/river/internal/notifylimiter" "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/internal/util/dbutil" "github.com/riverqueue/river/internal/workunit" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/baseservice" @@ -823,7 +824,10 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client { periodicJobEnqueuer := maintenance.NewPeriodicJobEnqueuer(archetype, &maintenance.PeriodicJobEnqueuerConfig{ AdvisoryLockPrefix: config.AdvisoryLockPrefix, - Insert: client.insertMany, + Insert: func(ctx context.Context, execTx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error { + _, err := client.insertMany(ctx, execTx, insertParams) + return err + }, }, driver.GetExecutor()) maintenanceServices = append(maintenanceServices, periodicJobEnqueuer) client.testSignals.periodicJobEnqueuer = &periodicJobEnqueuer.TestSignals @@ -1272,7 +1276,18 @@ func (c *Client[TTx]) Driver() riverdriver.Driver[TTx] { // Returns the up-to-date JobRow for the specified jobID if it exists. Returns // ErrNotFound if the job doesn't exist. func (c *Client[TTx]) JobCancel(ctx context.Context, jobID int64) (*rivertype.JobRow, error) { - return c.jobCancel(ctx, c.driver.GetExecutor(), jobID) + job, err := c.jobCancel(ctx, c.driver.GetExecutor(), jobID) + if err != nil { + return nil, err + } + + c.notifyProducerWithoutListenerQueueControlEvent(job.Queue, &controlEventPayload{ + Action: controlActionCancel, + JobID: job.ID, + Queue: job.Queue, + }) + + return job, nil } // JobCancelTx cancels the job with the given ID within the specified @@ -1535,21 +1550,16 @@ func (c *Client[TTx]) Insert(ctx context.Context, args JobArgs, opts *InsertOpts return nil, errNoDriverDBPool } - tx, err := c.driver.GetExecutor().Begin(ctx) + res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) (*insertManySharedResult, error) { + return c.validateParamsAndInsertMany(ctx, execTx, []InsertManyParams{{Args: args, InsertOpts: opts}}) + }) if err != nil { return nil, err } - defer tx.Rollback(ctx) - inserted, err := c.insert(ctx, tx, args, opts) - if err != nil { - return nil, err - } + c.notifyProducerWithoutListenerJobFetch(res.QueuesDeduped) - if err := tx.Commit(ctx); err != nil { - return nil, err - } - return inserted, nil + return res.InsertResults[0], nil } // InsertTx inserts a new job with the provided args on the given transaction. @@ -1570,17 +1580,11 @@ func (c *Client[TTx]) Insert(ctx context.Context, args JobArgs, opts *InsertOpts // transactions, the job will not be worked until the transaction has committed, // and if the transaction rolls back, so too is the inserted job. func (c *Client[TTx]) InsertTx(ctx context.Context, tx TTx, args JobArgs, opts *InsertOpts) (*rivertype.JobInsertResult, error) { - return c.insert(ctx, c.driver.UnwrapExecutor(tx), args, opts) -} - -func (c *Client[TTx]) insert(ctx context.Context, tx riverdriver.ExecutorTx, args JobArgs, opts *InsertOpts) (*rivertype.JobInsertResult, error) { - params := []InsertManyParams{{Args: args, InsertOpts: opts}} - results, err := c.validateParamsAndInsertMany(ctx, tx, params) + res, err := c.validateParamsAndInsertMany(ctx, c.driver.UnwrapExecutor(tx), []InsertManyParams{{Args: args, InsertOpts: opts}}) if err != nil { return nil, err } - - return results[0], nil + return res.InsertResults[0], nil } // InsertManyParams encapsulates a single job combined with insert options for @@ -1612,21 +1616,16 @@ func (c *Client[TTx]) InsertMany(ctx context.Context, params []InsertManyParams) return nil, errNoDriverDBPool } - tx, err := c.driver.GetExecutor().Begin(ctx) + res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) (*insertManySharedResult, error) { + return c.validateParamsAndInsertMany(ctx, execTx, params) + }) if err != nil { return nil, err } - defer tx.Rollback(ctx) - inserted, err := c.validateParamsAndInsertMany(ctx, tx, params) - if err != nil { - return nil, err - } + c.notifyProducerWithoutListenerJobFetch(res.QueuesDeduped) - if err := tx.Commit(ctx); err != nil { - return nil, err - } - return inserted, nil + return res.InsertResults, nil } // InsertManyTx inserts many jobs at once. Each job is inserted as an @@ -1648,28 +1647,31 @@ func (c *Client[TTx]) InsertMany(ctx context.Context, params []InsertManyParams) // changes. An inserted job isn't visible to be worked until the transaction // commits, and if the transaction rolls back, so too is the inserted job. func (c *Client[TTx]) InsertManyTx(ctx context.Context, tx TTx, params []InsertManyParams) ([]*rivertype.JobInsertResult, error) { - exec := c.driver.UnwrapExecutor(tx) - return c.validateParamsAndInsertMany(ctx, exec, params) + res, err := c.validateParamsAndInsertMany(ctx, c.driver.UnwrapExecutor(tx), params) + if err != nil { + return nil, err + } + return res.InsertResults, nil } // validateParamsAndInsertMany is a helper method that wraps the insertMany // method to provide param validation and conversion prior to calling the actual // insertMany method. This allows insertMany to be reused by the // PeriodicJobEnqueuer which cannot reference top-level river package types. -func (c *Client[TTx]) validateParamsAndInsertMany(ctx context.Context, tx riverdriver.ExecutorTx, params []InsertManyParams) ([]*rivertype.JobInsertResult, error) { +func (c *Client[TTx]) validateParamsAndInsertMany(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) (*insertManySharedResult, error) { insertParams, err := c.insertManyParams(params) if err != nil { return nil, err } - return c.insertMany(ctx, tx, insertParams) + return c.insertMany(ctx, execTx, insertParams) } // insertMany is a shared code path for InsertMany and InsertManyTx, also used // by the PeriodicJobEnqueuer. -func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) { - return c.insertManyShared(ctx, tx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { - results, err := c.pilot.JobInsertMany(ctx, tx, &riverdriver.JobInsertFastManyParams{ +func (c *Client[TTx]) insertMany(ctx context.Context, execTx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) (*insertManySharedResult, error) { + return c.insertManyShared(ctx, execTx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { + results, err := c.pilot.JobInsertMany(ctx, execTx, &riverdriver.JobInsertFastManyParams{ Jobs: insertParams, Schema: c.config.Schema, }) @@ -1685,6 +1687,11 @@ func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, }) } +type insertManySharedResult struct { + InsertResults []*rivertype.JobInsertResult + QueuesDeduped []string +} + // The shared code path for all Insert and InsertMany methods. It takes a // function that executes the actual insert operation and allows for different // implementations of the insert query to be passed in, each mapping their @@ -1694,7 +1701,9 @@ func (c *Client[TTx]) insertManyShared( tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams, execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error), -) ([]*rivertype.JobInsertResult, error) { +) (*insertManySharedResult, error) { + var queuesDeduped []string + doInner := func(ctx context.Context) ([]*rivertype.JobInsertResult, error) { for _, params := range insertParams { for _, hook := range append( @@ -1710,9 +1719,10 @@ func (c *Client[TTx]) insertManyShared( finalInsertParams := sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams { return (*riverdriver.JobInsertFastParams)(params) }) - results, err := execute(ctx, finalInsertParams) + + insertResults, err := execute(ctx, finalInsertParams) if err != nil { - return results, err + return insertResults, err } queues := make([]string, 0, 10) @@ -1721,10 +1731,14 @@ func (c *Client[TTx]) insertManyShared( queues = append(queues, params.Queue) } } - if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil { + + queuesDeduped = sliceutil.Uniq(queues) + + if err = c.maybeNotifyInsertForQueues(ctx, tx, queuesDeduped); err != nil { return nil, err } - return results, nil + + return insertResults, nil } jobInsertMiddleware := c.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindJobInsert) @@ -1740,7 +1754,15 @@ func (c *Client[TTx]) insertManyShared( } } - return doInner(ctx) + insertResults, err := doInner(ctx) + if err != nil { + return nil, err + } + + return &insertManySharedResult{ + InsertResults: insertResults, + QueuesDeduped: queuesDeduped, + }, nil } // Validates input parameters for a batch insert operation and generates a set @@ -1767,6 +1789,25 @@ func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*rivertype. return insertParams, nil } +// Notifies an internal producer of new jobs being queued for work. Only +// invoked if the client's driver doesn't support a listener. If a listener is +// supported, job notifications go out via listen/notify instead. +// +// Should only ever be invoked *outside* a transaction. If invoked within a +// transaction, the producer wouldn't yet be able to access the new jobs that +// triggered the notification because they're not committed yet. +func (c *Client[TTx]) notifyProducerWithoutListenerJobFetch(queuesDeduped []string) { + if c.driver.SupportsListener() || len(c.producersByQueueName) < 1 { + return + } + + for _, queue := range queuesDeduped { + if producer, ok := c.producersByQueueName[queue]; ok { + producer.TriggerJobFetch() + } + } +} + // InsertManyFast inserts many jobs at once using Postgres' `COPY FROM` mechanism, // making the operation quite fast and memory efficient. Each job is inserted as // an InsertManyParams tuple, which takes job args along with an optional set of @@ -1791,20 +1832,16 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar } // Wrap in a transaction in case we need to notify about inserts. - tx, err := c.driver.GetExecutor().Begin(ctx) + res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) (*insertManySharedResult, error) { + return c.insertManyFast(ctx, execTx, params) + }) if err != nil { return 0, err } - defer tx.Rollback(ctx) - inserted, err := c.insertManyFast(ctx, tx, params) - if err != nil { - return 0, err - } - if err := tx.Commit(ctx); err != nil { - return 0, err - } - return inserted, nil + c.notifyProducerWithoutListenerJobFetch(res.QueuesDeduped) + + return len(res.InsertResults), nil } // InsertManyTx inserts many jobs at once using Postgres' `COPY FROM` mechanism, @@ -1831,17 +1868,20 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar // a unique constraint is violated, the operation will fail and no jobs will be // inserted. func (c *Client[TTx]) InsertManyFastTx(ctx context.Context, tx TTx, params []InsertManyParams) (int, error) { - exec := c.driver.UnwrapExecutor(tx) - return c.insertManyFast(ctx, exec, params) + res, err := c.insertManyFast(ctx, c.driver.UnwrapExecutor(tx), params) + if err != nil { + return 0, err + } + return len(res.InsertResults), nil } -func (c *Client[TTx]) insertManyFast(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) (int, error) { +func (c *Client[TTx]) insertManyFast(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) (*insertManySharedResult, error) { insertParams, err := c.insertManyParams(params) if err != nil { - return 0, err + return nil, err } - results, err := c.insertManyShared(ctx, execTx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { + return c.insertManyShared(ctx, execTx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { count, err := execTx.JobInsertFastManyNoReturning(ctx, &riverdriver.JobInsertFastManyParams{ Jobs: insertParams, Schema: c.config.Schema, @@ -1851,34 +1891,25 @@ func (c *Client[TTx]) insertManyFast(ctx context.Context, execTx riverdriver.Exe } return make([]*rivertype.JobInsertResult, count), nil }) - if err != nil { - return 0, err - } - - return len(results), nil } // Notify the given queues that new jobs are available. The queues list will be // deduplicated and each will be checked to see if it is due for an insert // notification from this client. -func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error { - if len(queues) < 1 { +func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdriver.ExecutorTx, queuesDeduped []string) error { + if len(queuesDeduped) < 1 { return nil } - queueMap := make(map[string]struct{}) - queuesDeduped := make([]string, 0, len(queues)) - payloads := make([]string, 0, len(queues)) - - for _, queue := range queues { - if _, ok := queueMap[queue]; ok { - continue - } + var ( + payloads = make([]string, 0, len(queuesDeduped)) + queuesTriggered = make([]string, 0, len(queuesDeduped)) + ) - queueMap[queue] = struct{}{} + for _, queue := range queuesDeduped { if c.insertNotifyLimiter.ShouldTrigger(queue) { payloads = append(payloads, fmt.Sprintf("{\"queue\": %q}", queue)) - queuesDeduped = append(queuesDeduped, queue) + queuesTriggered = append(queuesTriggered, queue) } } @@ -1896,7 +1927,7 @@ func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdr c.baseService.Logger.ErrorContext( ctx, c.baseService.Name+": Failed to send job insert notification", - slog.String("queues", strings.Join(queuesDeduped, ",")), + slog.String("queues", strings.Join(queuesTriggered, ",")), slog.String("err", err.Error()), ) return err @@ -1907,7 +1938,7 @@ func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdr } // emit a notification about a queue being paused or resumed. -func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriver.ExecutorTx, action controlAction, queue string, opts *QueuePauseOpts) error { +func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriver.ExecutorTx, action controlAction, queue string, opts *QueuePauseOpts) (*controlEventPayload, error) { c.baseService.Logger.DebugContext(ctx, c.baseService.Name+": Notifying about queue state change", slog.String("action", string(action)), @@ -1915,9 +1946,11 @@ func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriv slog.String("opts", fmt.Sprintf("%+v", opts)), ) - payload, err := json.Marshal(controlEventPayload{Action: action, Queue: queue}) + controlEvent := &controlEventPayload{Action: action, Queue: queue} + + payload, err := json.Marshal(controlEvent) if err != nil { - return err + return nil, err } if c.driver.SupportsListenNotify() { @@ -1932,11 +1965,11 @@ func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriv c.baseService.Name+": Failed to send queue state change notification", slog.String("err", err.Error()), ) - return err + return nil, err } } - return nil + return controlEvent, nil } // Validates job args prior to insertion. Currently, verifies that a worker to @@ -2221,11 +2254,18 @@ func (c *Client[TTx]) QueuePause(ctx context.Context, name string, opts *QueuePa return err } - if err := c.notifyQueuePauseOrResume(ctx, tx, controlActionPause, name, opts); err != nil { + controlEvent, err := c.notifyQueuePauseOrResume(ctx, tx, controlActionPause, name, opts) + if err != nil { + return err + } + + if err = tx.Commit(ctx); err != nil { return err } - return tx.Commit(ctx) + c.notifyProducerWithoutListenerQueueControlEvent(name, controlEvent) + + return nil } // QueuePauseTx pauses the queue with the given name. When a queue is paused, @@ -2249,7 +2289,7 @@ func (c *Client[TTx]) QueuePauseTx(ctx context.Context, tx TTx, name string, opt return err } - if err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionPause, name, opts); err != nil { + if _, err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionPause, name, opts); err != nil { return err } @@ -2282,11 +2322,18 @@ func (c *Client[TTx]) QueueResume(ctx context.Context, name string, opts *QueueP return err } - if err := c.notifyQueuePauseOrResume(ctx, tx, controlActionResume, name, opts); err != nil { + controlEvent, err := c.notifyQueuePauseOrResume(ctx, tx, controlActionResume, name, opts) + if err != nil { return err } - return tx.Commit(ctx) + if err = tx.Commit(ctx); err != nil { + return err + } + + c.notifyProducerWithoutListenerQueueControlEvent(name, controlEvent) + + return nil } // QueueResume resumes the queue with the given name. If the queue was @@ -2311,7 +2358,7 @@ func (c *Client[TTx]) QueueResumeTx(ctx context.Context, tx TTx, name string, op return err } - if err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionResume, name, opts); err != nil { + if _, err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionResume, name, opts); err != nil { return err } @@ -2334,7 +2381,7 @@ func (c *Client[TTx]) QueueUpdate(ctx context.Context, name string, params *Queu } defer tx.Rollback(ctx) - queue, err := c.queueUpdate(ctx, tx, name, params) + queue, controlEvent, err := c.queueUpdate(ctx, tx, name, params) if err != nil { return nil, err } @@ -2343,18 +2390,39 @@ func (c *Client[TTx]) QueueUpdate(ctx context.Context, name string, params *Queu return nil, err } + c.notifyProducerWithoutListenerQueueControlEvent(name, controlEvent) + return queue, nil } // QueueUpdateTx updates a queue's settings in the database. These settings // override the settings in the client (if applied). func (c *Client[TTx]) QueueUpdateTx(ctx context.Context, tx TTx, name string, params *QueueUpdateParams) (*rivertype.Queue, error) { - executorTx := c.driver.UnwrapExecutor(tx) + queue, _, err := c.queueUpdate(ctx, c.driver.UnwrapExecutor(tx), name, params) + if err != nil { + return nil, err + } + return queue, nil +} - return c.queueUpdate(ctx, executorTx, name, params) +// Notifies an internal producer of a queue control event like pause/resume. +// Only invoked if the client's driver doesn't support a listener. If a listener +// is supported, control events go out via listen/notify instead. +// +// Should only ever be invoked *outside* a transaction. If invoked within a +// transaction, the producer wouldn't yet be able to access the state that +// triggered the notification because it's not committed yet. +func (c *Client[TTx]) notifyProducerWithoutListenerQueueControlEvent(queue string, controlEvent *controlEventPayload) { + if c.driver.SupportsListener() || len(c.producersByQueueName) < 1 { + return + } + + if producer, ok := c.producersByQueueName[queue]; ok { + producer.TriggerQueueControlEvent(controlEvent) + } } -func (c *Client[TTx]) queueUpdate(ctx context.Context, executorTx riverdriver.ExecutorTx, name string, params *QueueUpdateParams) (*rivertype.Queue, error) { +func (c *Client[TTx]) queueUpdate(ctx context.Context, executorTx riverdriver.ExecutorTx, name string, params *QueueUpdateParams) (*rivertype.Queue, *controlEventPayload, error) { updateMetadata := len(params.Metadata) > 0 queue, err := executorTx.QueueUpdate(ctx, &riverdriver.QueueUpdateParams{ @@ -2364,31 +2432,35 @@ func (c *Client[TTx]) queueUpdate(ctx context.Context, executorTx riverdriver.Ex Schema: c.config.Schema, }) if err != nil { - return nil, err + return nil, nil, err } - if updateMetadata { - payload, err := json.Marshal(controlEventPayload{ - Action: controlActionMetadataChanged, - Metadata: params.Metadata, - Queue: queue.Name, - }) - if err != nil { - return nil, err - } + if !updateMetadata { + return queue, nil, err + } - if c.driver.SupportsListenNotify() { - if err := executorTx.NotifyMany(ctx, &riverdriver.NotifyManyParams{ - Payload: []string{string(payload)}, - Schema: c.config.Schema, - Topic: string(notifier.NotificationTopicControl), - }); err != nil { - return nil, err - } + controlEvent := &controlEventPayload{ + Action: controlActionMetadataChanged, + Metadata: params.Metadata, + Queue: queue.Name, + } + + payload, err := json.Marshal(controlEvent) + if err != nil { + return nil, nil, err + } + + if c.driver.SupportsListenNotify() { + if err := executorTx.NotifyMany(ctx, &riverdriver.NotifyManyParams{ + Payload: []string{string(payload)}, + Schema: c.config.Schema, + Topic: string(notifier.NotificationTopicControl), + }); err != nil { + return nil, nil, err } } - return queue, nil + return queue, controlEvent, nil } // QueueBundle is a bundle for adding additional queues. It's made accessible diff --git a/client_test.go b/client_test.go index 783ed24b..c18267ba 100644 --- a/client_test.go +++ b/client_test.go @@ -36,6 +36,7 @@ import ( "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverdatabasesql" "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/riverdriver/riversqlite" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/riversharedtest" "github.com/riverqueue/river/rivershared/startstoptest" @@ -134,7 +135,6 @@ func newTestConfig(t *testing.T, schema string) *Config { }, TestOnly: true, // disables staggered start in maintenance services Workers: workers, - queuePollInterval: 50 * time.Millisecond, schedulerInterval: riverinternaltest.SchedulerShortInterval, } } @@ -608,6 +608,69 @@ func Test_Client(t *testing.T) { require.NoError(t, err) }) + t.Run("CancelProducerControlEventSent", func(t *testing.T) { + t.Parallel() + + var ( + driver = riversqlite.New(nil) + schema = riverdbtest.TestSchema(ctx, t, driver, &riverdbtest.TestSchemaOpts{ + ProcurePool: func(ctx context.Context, schema string) (any, string) { + return riversharedtest.DBPoolSQLite(ctx, t, schema), "" // could also be `main` instead of empty string + }, + }) + config = newTestConfig(t, schema) + ) + + client, err := NewClient(driver, config) + require.NoError(t, err) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + type JobArgs struct { + JobArgsReflectKind[JobArgs] + } + + AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + return nil + })) + + startClient(ctx, t, client) + + insertRes, err := client.Insert(ctx, &JobArgs{}, nil) + require.NoError(t, err) + + _, err = client.JobCancel(ctx, insertRes.Job.ID) + require.NoError(t, err) + + controlEvent := client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.WaitOrTimeout() + require.NotNil(t, controlEvent) + require.Equal(t, controlActionCancel, controlEvent.Action) + }) + + t.Run("CancelProducerControlEventNotSent", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + type JobArgs struct { + JobArgsReflectKind[JobArgs] + } + + AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + return nil + })) + + startClient(ctx, t, client) + + insertRes, err := client.Insert(ctx, &JobArgs{}, nil) + require.NoError(t, err) + + _, err = client.JobCancel(ctx, insertRes.Job.ID) + require.NoError(t, err) + + client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.RequireEmpty() + }) + t.Run("AlternateSchema", func(t *testing.T) { t.Parallel() @@ -1161,6 +1224,55 @@ func Test_Client(t *testing.T) { } }) + t.Run("QueuePauseAndResumeProducerControlEventSent", func(t *testing.T) { + t.Parallel() + + var ( + driver = riversqlite.New(nil) + schema = riverdbtest.TestSchema(ctx, t, driver, &riverdbtest.TestSchemaOpts{ + ProcurePool: func(ctx context.Context, schema string) (any, string) { + return riversharedtest.DBPoolSQLite(ctx, t, schema), "" // could also be `main` instead of empty string + }, + }) + config = newTestConfig(t, schema) + ) + + client, err := NewClient(driver, config) + require.NoError(t, err) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + require.NoError(t, client.QueuePause(ctx, QueueDefault, nil)) + + controlEvent := client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.WaitOrTimeout() + require.NotNil(t, controlEvent) + require.Equal(t, controlActionPause, controlEvent.Action) + + require.NoError(t, client.QueueResume(ctx, QueueDefault, nil)) + + controlEvent = client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.WaitOrTimeout() + require.NotNil(t, controlEvent) + require.Equal(t, controlActionResume, controlEvent.Action) + }) + + t.Run("QueuePauseAndResumeProducerControlEventNotSent", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + require.NoError(t, client.QueuePause(ctx, QueueDefault, nil)) + + client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.RequireEmpty() + + require.NoError(t, client.QueueResume(ctx, QueueDefault, nil)) + + client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.RequireEmpty() + }) + t.Run("PollOnlyDriver", func(t *testing.T) { t.Parallel() @@ -1172,7 +1284,6 @@ func Test_Client(t *testing.T) { client, err := NewClient(riverdatabasesql.New(stdPool), config) require.NoError(t, err) - client.testSignals.Init(t) // Notifier should not have been initialized at all. @@ -2018,6 +2129,46 @@ func Test_Client_Insert(t *testing.T) { require.Equal(t, []string{}, jobRow.Tags) }) + t.Run("ProducerFetchLimiterCalled", func(t *testing.T) { + t.Parallel() + + var ( + driver = riversqlite.New(nil) + schema = riverdbtest.TestSchema(ctx, t, driver, &riverdbtest.TestSchemaOpts{ + ProcurePool: func(ctx context.Context, schema string) (any, string) { + return riversharedtest.DBPoolSQLite(ctx, t, schema), "" // could also be `main` instead of empty string + }, + }) + config = newTestConfig(t, schema) + ) + + client, err := NewClient(driver, config) + require.NoError(t, err) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + _, err = client.Insert(ctx, &noOpArgs{}, nil) + require.NoError(t, err) + + client.producersByQueueName[QueueDefault].testSignals.JobFetchTriggered.WaitOrTimeout() + }) + + // Not called for drivers that support a listener. + t.Run("ProducerFetchLimiterNotCalled", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + _, err := client.Insert(ctx, &noOpArgs{}, nil) + require.NoError(t, err) + + client.producersByQueueName[QueueDefault].testSignals.JobFetchTriggered.RequireEmpty() + }) + t.Run("WithInsertOpts", func(t *testing.T) { t.Parallel() @@ -4948,6 +5099,52 @@ func Test_Client_QueueUpdate(t *testing.T) { case <-time.After(100 * time.Millisecond): } }) + + t.Run("ProducerControlEventSent", func(t *testing.T) { + t.Parallel() + + var ( + driver = riversqlite.New(nil) + schema = riverdbtest.TestSchema(ctx, t, driver, &riverdbtest.TestSchemaOpts{ + ProcurePool: func(ctx context.Context, schema string) (any, string) { + return riversharedtest.DBPoolSQLite(ctx, t, schema), "" // could also be `main` instead of empty string + }, + }) + config = newTestConfig(t, schema) + ) + + client, err := NewClient(driver, config) + require.NoError(t, err) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + _, err = client.QueueUpdate(ctx, QueueDefault, &QueueUpdateParams{ + Metadata: []byte(`{"foo":"baz"}`), + }) + require.NoError(t, err) + + controlEvent := client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.WaitOrTimeout() + require.NotNil(t, controlEvent) + require.Equal(t, controlActionMetadataChanged, controlEvent.Action) + }) + + t.Run("ProducerControlEventNotSent", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + _, err := client.QueueUpdate(ctx, QueueDefault, &QueueUpdateParams{ + Metadata: []byte(`{"foo":"baz"}`), + }) + require.NoError(t, err) + + client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.RequireEmpty() + }) } func Test_Client_QueueUpdateTx(t *testing.T) { diff --git a/example_sqlite_test.go b/example_sqlite_test.go index 976dc7a6..694e95a2 100644 --- a/example_sqlite_test.go +++ b/example_sqlite_test.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "log/slog" - "time" "github.com/riverqueue/river" "github.com/riverqueue/river/riverdriver" @@ -36,9 +35,7 @@ func Example_sqlite() { river.AddWorker(workers, &SortWorker{}) riverClient, err := river.NewClient(driver, &river.Config{ - FetchCooldown: 20 * time.Millisecond, - FetchPollInterval: 50 * time.Millisecond, // this driver is poll only, so speed up poll interval so the test runs fater - Logger: slog.New(&slogutil.SlogMessageOnlyHandler{Level: slog.LevelWarn}), + Logger: slog.New(&slogutil.SlogMessageOnlyHandler{Level: slog.LevelWarn}), Queues: map[string]river.QueueConfig{ river.QueueDefault: {MaxWorkers: 100}, }, diff --git a/internal/maintenance/job_scheduler.go b/internal/maintenance/job_scheduler.go index 9021ae5c..ac1eb9dc 100644 --- a/internal/maintenance/job_scheduler.go +++ b/internal/maintenance/job_scheduler.go @@ -14,6 +14,7 @@ import ( "github.com/riverqueue/river/rivershared/testsignal" "github.com/riverqueue/river/rivershared/util/randutil" "github.com/riverqueue/river/rivershared/util/serviceutil" + "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivershared/util/testutil" "github.com/riverqueue/river/rivershared/util/timeutil" "github.com/riverqueue/river/rivertype" @@ -35,11 +36,11 @@ func (ts *JobSchedulerTestSignals) Init(tb testutil.TestingTB) { ts.ScheduledBatch.Init(tb) } -type InsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) +type InsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error // NotifyInsert is a function to call to emit notifications for queues where // jobs were scheduled. -type NotifyInsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error +type NotifyInsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, queuesDeduped []string) error type JobSchedulerConfig struct { // Interval is the amount of time between periodic checks for jobs to @@ -186,7 +187,7 @@ func (s *JobScheduler) runOnce(ctx context.Context) (*schedulerRunOnceResult, er } if len(queues) > 0 { - if err := s.config.NotifyInsert(ctx, tx, queues); err != nil { + if err := s.config.NotifyInsert(ctx, tx, sliceutil.Uniq(queues)); err != nil { return 0, fmt.Errorf("error notifying insert: %w", err) } s.TestSignals.NotifiedQueues.Signal(queues) diff --git a/internal/maintenance/job_scheduler_test.go b/internal/maintenance/job_scheduler_test.go index 0c934778..5944d3d4 100644 --- a/internal/maintenance/job_scheduler_test.go +++ b/internal/maintenance/job_scheduler_test.go @@ -320,8 +320,8 @@ func TestJobScheduler(t *testing.T) { scheduler, _ := setup(t, &testOpts{exec: exec, schema: schema}) scheduler.config.Interval = time.Minute // should only trigger once for the initial run - scheduler.config.NotifyInsert = func(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error { - notifyCh <- queues + scheduler.config.NotifyInsert = func(ctx context.Context, tx riverdriver.ExecutorTx, queuesDeduped []string) error { + notifyCh <- queuesDeduped return nil } now := time.Now().UTC() @@ -362,10 +362,10 @@ func TestJobScheduler(t *testing.T) { require.NoError(t, scheduler.Start(ctx)) scheduler.TestSignals.ScheduledBatch.WaitOrTimeout() - expectedQueues := []string{"queue1", "queue2", "queue2", "queue3", "queue4"} + expectedQueues := []string{"queue1", "queue2", "queue3", "queue4"} - notifiedQueues := riversharedtest.WaitOrTimeout(t, notifyCh) - sort.Strings(notifiedQueues) - require.Equal(t, expectedQueues, notifiedQueues) + notifiedQueuesDeduped := riversharedtest.WaitOrTimeout(t, notifyCh) + sort.Strings(notifiedQueuesDeduped) + require.Equal(t, expectedQueues, notifiedQueuesDeduped) }) } diff --git a/internal/maintenance/periodic_job_enqueuer.go b/internal/maintenance/periodic_job_enqueuer.go index 90e5d6aa..fe11a834 100644 --- a/internal/maintenance/periodic_job_enqueuer.go +++ b/internal/maintenance/periodic_job_enqueuer.go @@ -342,8 +342,7 @@ func (s *PeriodicJobEnqueuer) insertBatch(ctx context.Context, insertParamsMany defer tx.Rollback(ctx) if len(insertParamsMany) > 0 { - _, err := s.Config.Insert(ctx, tx, insertParamsMany) - if err != nil { + if err := s.Config.Insert(ctx, tx, insertParamsMany); err != nil { s.Logger.ErrorContext(ctx, s.Name+": Error inserting periodic jobs", "error", err.Error(), "num_jobs", len(insertParamsMany)) return diff --git a/internal/maintenance/periodic_job_enqueuer_test.go b/internal/maintenance/periodic_job_enqueuer_test.go index 56bfbbd6..22a84707 100644 --- a/internal/maintenance/periodic_job_enqueuer_test.go +++ b/internal/maintenance/periodic_job_enqueuer_test.go @@ -79,22 +79,15 @@ func TestPeriodicJobEnqueuer(t *testing.T) { // A simplified version of `Client.insertMany` that only inserts jobs directly // via the driver instead of using the pilot. - makeInsertFunc := func(schema string) func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) { - return func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) { - results, err := tx.JobInsertFastMany(ctx, &riverdriver.JobInsertFastManyParams{ + makeInsertFunc := func(schema string) func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error { + return func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error { + _, err := tx.JobInsertFastMany(ctx, &riverdriver.JobInsertFastManyParams{ Jobs: sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams { return (*riverdriver.JobInsertFastParams)(params) }), Schema: schema, }) - if err != nil { - return nil, err - } - return sliceutil.Map(results, - func(result *riverdriver.JobInsertFastResult) *rivertype.JobInsertResult { - return (*rivertype.JobInsertResult)(result) - }, - ), nil + return err } } diff --git a/producer.go b/producer.go index f2417769..86fce3ad 100644 --- a/producer.go +++ b/producer.go @@ -41,21 +41,25 @@ const ( // Test-only properties. type producerTestSignals struct { - DeletedExpiredQueueRecords testsignal.TestSignal[struct{}] // notifies when the producer deletes expired queue records - MetadataChanged testsignal.TestSignal[struct{}] // notifies when the producer detects a metadata change - Paused testsignal.TestSignal[struct{}] // notifies when the producer is paused - PolledQueueConfig testsignal.TestSignal[struct{}] // notifies when the producer polls for queue settings - ReportedProducerStatus testsignal.TestSignal[struct{}] // notifies when the producer reports its own status - ReportedQueueStatus testsignal.TestSignal[struct{}] // notifies when the producer reports queue status - Resumed testsignal.TestSignal[struct{}] // notifies when the producer is resumed - StartedExecutors testsignal.TestSignal[struct{}] // notifies when runOnce finishes a pass + DeletedExpiredQueueRecords testsignal.TestSignal[struct{}] // notifies when the producer deletes expired queue records + JobFetchTriggered testsignal.TestSignal[struct{}] // notifies when the producer's fetch limiter is triggered via triggerJobFetch + MetadataChanged testsignal.TestSignal[struct{}] // notifies when the producer detects a metadata change + Paused testsignal.TestSignal[struct{}] // notifies when the producer is paused + PolledQueueConfig testsignal.TestSignal[struct{}] // notifies when the producer polls for queue settings + QueueControlEventTriggered testsignal.TestSignal[*controlEventPayload] // notifies when a queue control event is triggered via triggerQueueControlEvent + ReportedProducerStatus testsignal.TestSignal[struct{}] // notifies when the producer reports its own status + ReportedQueueStatus testsignal.TestSignal[struct{}] // notifies when the producer reports queue status + Resumed testsignal.TestSignal[struct{}] // notifies when the producer is resumed + StartedExecutors testsignal.TestSignal[struct{}] // notifies when runOnce finishes a pass } func (ts *producerTestSignals) Init(tb testutil.TestingTB) { ts.DeletedExpiredQueueRecords.Init(tb) + ts.JobFetchTriggered.Init(tb) ts.MetadataChanged.Init(tb) ts.Paused.Init(tb) ts.PolledQueueConfig.Init(tb) + ts.QueueControlEventTriggered.Init(tb) ts.ReportedQueueStatus.Init(tb) ts.ReportedProducerStatus.Init(tb) ts.Resumed.Init(tb) @@ -182,6 +186,7 @@ type producer struct { id atomic.Int64 // atomic because it's written at startup and read during shutdown exec riverdriver.Executor errorHandler jobexecutor.ErrorHandler + fetchLimiter *chanutil.DebouncedChan state riverpilot.ProducerState pilot riverpilot.Pilot workers *Workers @@ -324,7 +329,7 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { p.id.Store(id) // TODO: fetcher should have some jitter in it to avoid stampeding issues. - fetchLimiter := chanutil.NewDebouncedChan(fetchCtx, p.config.FetchCooldown, true) + p.fetchLimiter = chanutil.NewDebouncedChan(fetchCtx, p.config.FetchCooldown, true) var ( controlSub *notifier.Subscription @@ -343,7 +348,7 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { return } p.Logger.DebugContext(workCtx, p.Name+": Received insert notification", slog.String("queue", decoded.Queue)) - fetchLimiter.Call() + p.fetchLimiter.Call() } insertSub, err = p.config.Notifier.Listen(fetchCtx, notifier.NotificationTopicInsert, handleInsertNotification) if err != nil { @@ -395,7 +400,7 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { go p.pollForSettingChanges(subroutineCtx, &subroutineWG, initiallyPaused, initialMetadata) } - p.fetchAndRunLoop(fetchCtx, workCtx, fetchLimiter) + p.fetchAndRunLoop(fetchCtx, workCtx) p.Logger.Debug(p.Name+": Entering shutdown loop", slog.String("queue", p.config.Queue), slog.Int64("id", p.id.Load())) p.executorShutdownLoop() @@ -410,6 +415,29 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { return nil } +// TriggerJobFetch manually triggers the producer to perform a job fetch +// (although it's debounced, so it may not happen immediately if a fetch was +// performed very recently). This is used by clients using drivers that don't +// support listeners to wake a producer immediately after a job insert was known +// to be performed so the producer doesn't have to wait on polling. +func (p *producer) TriggerJobFetch() { + if p.fetchLimiter != nil { + p.fetchLimiter.Call() + } + p.testSignals.JobFetchTriggered.Signal(struct{}{}) +} + +// TriggerQueueControlEvent manually injects a queue control event into the +// producer's queue control channel as if it'd been received through +// listen/notify. This is used by clients using drivers that don't support +// listeners to wake a producer immediately after a queue control event was +// known to be performed so the producer doesn't have to wait on polling. +func (p *producer) TriggerQueueControlEvent(controlEvent *controlEventPayload) { + p.queueControlCh <- controlEvent + p.testSignals.QueueControlEventTriggered.Signal(controlEvent) + +} + type controlAction string const ( @@ -475,10 +503,10 @@ func (p *producer) handleControlNotification(workCtx context.Context) func(notif } } -func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimiter *chanutil.DebouncedChan) { +func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context) { // Prime the fetchLimiter so we can make an initial fetch without waiting for // an insert notification or a fetch poll. - fetchLimiter.Call() + p.fetchLimiter.Call() fetchPollTimer := time.NewTimer(p.config.FetchPollInterval) go func() { @@ -491,7 +519,7 @@ func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimit } return case <-fetchPollTimer.C: - fetchLimiter.Call() + p.fetchLimiter.Call() fetchPollTimer.Reset(p.config.FetchPollInterval) } } @@ -533,7 +561,7 @@ func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimit } p.paused = false p.Logger.DebugContext(workCtx, p.Name+": Resumed", slog.String("queue", p.config.Queue), slog.String("queue_in_message", msg.Queue)) - fetchLimiter.Call() // try another fetch because more jobs may be available to run which were gated behind the paused queue + p.fetchLimiter.Call() // try another fetch because more jobs may be available to run which were gated behind the paused queue p.testSignals.Resumed.Signal(struct{}{}) if p.config.QueueEventCallback != nil { p.config.QueueEventCallback(&Event{Kind: EventKindQueueResumed, Queue: &rivertype.Queue{Name: p.config.Queue}}) @@ -543,7 +571,7 @@ func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimit } case jobID := <-p.cancelCh: p.maybeCancelJob(jobID) - case <-fetchLimiter.C(): + case <-p.fetchLimiter.C(): p.innerFetchLoop(workCtx, fetchResultCh) // Ensure we can't start another fetch when fetchCtx is done, even if // the fetchLimiter is also ready to fire: @@ -560,7 +588,7 @@ func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimit // more aggressive triggering the fetch limiter now that we have a slot // available. p.fetchWhenSlotsAreAvailable = false - fetchLimiter.Call() + p.fetchLimiter.Call() } } } diff --git a/rivershared/riverpilot/pilot.go b/rivershared/riverpilot/pilot.go index c91e32e8..a6cee157 100644 --- a/rivershared/riverpilot/pilot.go +++ b/rivershared/riverpilot/pilot.go @@ -25,11 +25,11 @@ type Pilot interface { JobInsertMany( ctx context.Context, - tx riverdriver.ExecutorTx, + execTx riverdriver.ExecutorTx, params *riverdriver.JobInsertFastManyParams, ) ([]*riverdriver.JobInsertFastResult, error) - JobSetStateIfRunningMany(ctx context.Context, tx riverdriver.ExecutorTx, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) + JobSetStateIfRunningMany(ctx context.Context, execTx riverdriver.ExecutorTx, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) PilotInit(archetype *baseservice.Archetype) diff --git a/rivershared/riverpilot/standard.go b/rivershared/riverpilot/standard.go index 30fe773e..ea12b207 100644 --- a/rivershared/riverpilot/standard.go +++ b/rivershared/riverpilot/standard.go @@ -22,14 +22,14 @@ func (p *StandardPilot) JobGetAvailable(ctx context.Context, exec riverdriver.Ex func (p *StandardPilot) JobInsertMany( ctx context.Context, - tx riverdriver.ExecutorTx, + execTx riverdriver.ExecutorTx, params *riverdriver.JobInsertFastManyParams, ) ([]*riverdriver.JobInsertFastResult, error) { - return tx.JobInsertFastMany(ctx, params) + return execTx.JobInsertFastMany(ctx, params) } -func (p *StandardPilot) JobSetStateIfRunningMany(ctx context.Context, tx riverdriver.ExecutorTx, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) { - return tx.JobSetStateIfRunningMany(ctx, params) +func (p *StandardPilot) JobSetStateIfRunningMany(ctx context.Context, execTx riverdriver.ExecutorTx, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) { + return execTx.JobSetStateIfRunningMany(ctx, params) } func (p *StandardPilot) PilotInit(archetype *baseservice.Archetype) { diff --git a/rivershared/testsignal/test_signal.go b/rivershared/testsignal/test_signal.go index b8a20b03..845502a1 100644 --- a/rivershared/testsignal/test_signal.go +++ b/rivershared/testsignal/test_signal.go @@ -1,7 +1,6 @@ package testsignal import ( - "fmt" "time" "github.com/riverqueue/river/rivershared/riversharedtest" @@ -38,6 +37,20 @@ func (s *TestSignal[T]) Init(tb testutil.TestingTB) { s.tb = tb } +// RequireEmpty requires that the test signal be empty (i.e. have not received +// any values). +func (s *TestSignal[T]) RequireEmpty() { + if s.internalChan == nil { + panic("test only signal is not initialized; called outside of tests?") + } + + select { + case val := <-s.internalChan: + s.tb.Errorf("test signal should be empty, but wasn't\ngot value: %v\n", val) + default: + } +} + // Signal signals the test signal. In production where the signal hasn't been // initialized, this no ops harmlessly. In tests, the value is written to an // internal asynchronous channel which can be waited with WaitOrTimeout. @@ -52,7 +65,6 @@ func (s *TestSignal[T]) Signal(val T) { case s.internalChan <- val: default: s.tb.Errorf("test only signal channel is full") - s.tb.FailNow() } } @@ -77,9 +89,12 @@ func (s *TestSignal[T]) WaitOrTimeout() T { timeout := riversharedtest.WaitTimeout() select { - case value := <-s.internalChan: - return value + case val := <-s.internalChan: + return val case <-time.After(timeout): - panic(fmt.Sprintf("timed out waiting on test signal after %s", timeout)) + s.tb.Errorf("timed out waiting on test signal after %s", timeout) } + + var val T + return val } diff --git a/rivershared/testsignal/test_signal_test.go b/rivershared/testsignal/test_signal_test.go index 2f1bd656..6e8e22d1 100644 --- a/rivershared/testsignal/test_signal_test.go +++ b/rivershared/testsignal/test_signal_test.go @@ -49,6 +49,27 @@ func TestTestSignal(t *testing.T) { } }) + t.Run("RequireEmpty", func(t *testing.T) { + t.Parallel() + + signal := TestSignal[struct{}]{} + + require.PanicsWithValue(t, "test only signal is not initialized; called outside of tests?", func() { + signal.RequireEmpty() + }) + + mockT := testutil.NewMockT(t) + signal.Init(mockT) + + signal.RequireEmpty() // succeeds + + signal.Signal(struct{}{}) + + signal.RequireEmpty() + require.True(t, mockT.Failed) + require.Equal(t, "test signal should be empty, but wasn't\ngot value: {}\n\n", mockT.LogOutput()) + }) + t.Run("WaitC", func(t *testing.T) { t.Parallel() @@ -74,6 +95,11 @@ func TestTestSignal(t *testing.T) { t.Parallel() signal := TestSignal[struct{}]{} + + require.PanicsWithValue(t, "test only signal is not initialized; called outside of tests?", func() { + signal.WaitOrTimeout() + }) + signal.Init(t) signal.Signal(struct{}{}) diff --git a/rivershared/util/sliceutil/slice_util.go b/rivershared/util/sliceutil/slice_util.go index a0b7262a..945de5af 100644 --- a/rivershared/util/sliceutil/slice_util.go +++ b/rivershared/util/sliceutil/slice_util.go @@ -54,3 +54,21 @@ func Map[T any, R any](collection []T, mapFunc func(T) R) []R { return result } + +// Uniq returns a duplicate-free version of an array, in which only the first occurrence of each element is kept. +// The order of result values is determined by the order they occur in the array. +func Uniq[T comparable](collection []T) []T { + result := make([]T, 0, len(collection)) + seen := make(map[T]struct{}, len(collection)) + + for _, item := range collection { + if _, ok := seen[item]; ok { + continue + } + + seen[item] = struct{}{} + result = append(result, item) + } + + return result +} diff --git a/rivershared/util/sliceutil/slice_util_test.go b/rivershared/util/sliceutil/slice_util_test.go index 9026df7e..bdac4bb1 100644 --- a/rivershared/util/sliceutil/slice_util_test.go +++ b/rivershared/util/sliceutil/slice_util_test.go @@ -89,3 +89,12 @@ func TestMap(t *testing.T) { require.Equal(t, []string{"Hello", "Hello", "Hello", "Hello"}, result1) require.Equal(t, []string{"1", "2", "3", "4"}, result2) } + +func TestUniq(t *testing.T) { + t.Parallel() + + result1 := Uniq([]int{1, 2, 2, 1}) + + require.Len(t, result1, 2) + require.Equal(t, []int{1, 2}, result1) +}