From 27fa8e162e7883a941f55ddf91fc36dbc7b5ccd1 Mon Sep 17 00:00:00 2001 From: Brandur Date: Sun, 25 May 2025 14:10:30 -0600 Subject: [PATCH] Out of transaction notification on job insert and queue changes Here, we try to implement a system that'll make drivers that don't support listen/notify like SQLite more responsive by default (i.e. without having to crank their fetch intervals down to miniscule numbers and have them fetch with hyper-frequency) so they respond to a new job or queue being paused quickly. Currently, drivers with listeners send out notifications which are then received by the producer so that it can enact work. Drivers without listeners are left out in the cold though because they don't have any good listen/notify alternative. This proposal takes non-tx functions like `Client.Insert` or `Client.QueuePause` and makes a special call to the producer to notify it of what happened so it can respond quickly without waiting for a poll. The approach of course has a couple limitations: * It only works for non-tx functions. If we were to notify the producer within a transaction, the producer would wake up and find nothing to do because the relevant data hasn't been persisted yet. This is taken care of on the listener drivers because listen/notify is transaction friendly. * Only works within the same process. A client obviously can't notify a producer in a different Go process, so those will still have to poll. Still, if even one producer can get a notification about an inserted job immediately, that job still gets worked faster. This improvement is also aimed pretty specifically aimed at our test suite, where there's only ever one process running. Being able to notify a producer within the same process means that we don't need to put an aggressive poll loop into ever example or client test that uses SQLite. Longer term, I'd like to replace this with something else, like a general purpose unlogged table that could process notifications in a transaction in a similar way to how listen/notify does it (although I'd have to do a little more work on this, I'm not sure it'd work), but I figure that this might be good enough for the time being. This doesn't change the public API at all, so it should be fairly easy to replace all changes made here with something more suitable if it comes up. --- CHANGELOG.md | 1 + client.go | 292 +++++++++++------- client_test.go | 201 +++++++++++- example_sqlite_test.go | 5 +- internal/maintenance/job_scheduler.go | 7 +- internal/maintenance/job_scheduler_test.go | 12 +- internal/maintenance/periodic_job_enqueuer.go | 3 +- .../maintenance/periodic_job_enqueuer_test.go | 15 +- producer.go | 62 +++- rivershared/riverpilot/pilot.go | 4 +- rivershared/riverpilot/standard.go | 8 +- rivershared/testsignal/test_signal.go | 25 +- rivershared/testsignal/test_signal_test.go | 26 ++ rivershared/util/sliceutil/slice_util.go | 18 ++ rivershared/util/sliceutil/slice_util_test.go | 9 + 15 files changed, 522 insertions(+), 166 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 213f2fd9..33031e34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - CLI `river migrate-get` now takes a `--schema` option to inject a custom schema into dumped migrations and schema comments are hidden if `--schema` option isn't provided. [PR #903](https://github.com/riverqueue/river/pull/903). - Added `riverlog.NewMiddlewareCustomContext` that makes the use of `riverlog` job-persisted logging possible with non-slog loggers. [PR #919](https://github.com/riverqueue/river/pull/919). - Added `RequireInsertedOpts.Schema`, allowing an explicit schema to be set when asserting on job inserts with `rivertest`. [PR #926](https://github.com/riverqueue/river/pull/926). +- When using a driver that doesn't support listen/notify, producers within same process are notified immediately of new job inserts and queue changes (e.g. pause/resume) without having to poll when non-transactional variants are used (i.e. `Insert` instead of `InsertTx`). [PR #928](https://github.com/riverqueue/river/pull/928). - Added `JobListParams.Where`, which provides an escape hatch for job listing that runs arbitrary SQL with named parameters. [PR #933](https://github.com/riverqueue/river/pull/933). ### Changed diff --git a/client.go b/client.go index 25b4d4f1..8a432444 100644 --- a/client.go +++ b/client.go @@ -23,6 +23,7 @@ import ( "github.com/riverqueue/river/internal/notifier" "github.com/riverqueue/river/internal/notifylimiter" "github.com/riverqueue/river/internal/rivercommon" + "github.com/riverqueue/river/internal/util/dbutil" "github.com/riverqueue/river/internal/workunit" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/baseservice" @@ -823,7 +824,10 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client { periodicJobEnqueuer := maintenance.NewPeriodicJobEnqueuer(archetype, &maintenance.PeriodicJobEnqueuerConfig{ AdvisoryLockPrefix: config.AdvisoryLockPrefix, - Insert: client.insertMany, + Insert: func(ctx context.Context, execTx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error { + _, err := client.insertMany(ctx, execTx, insertParams) + return err + }, }, driver.GetExecutor()) maintenanceServices = append(maintenanceServices, periodicJobEnqueuer) client.testSignals.periodicJobEnqueuer = &periodicJobEnqueuer.TestSignals @@ -1272,7 +1276,18 @@ func (c *Client[TTx]) Driver() riverdriver.Driver[TTx] { // Returns the up-to-date JobRow for the specified jobID if it exists. Returns // ErrNotFound if the job doesn't exist. func (c *Client[TTx]) JobCancel(ctx context.Context, jobID int64) (*rivertype.JobRow, error) { - return c.jobCancel(ctx, c.driver.GetExecutor(), jobID) + job, err := c.jobCancel(ctx, c.driver.GetExecutor(), jobID) + if err != nil { + return nil, err + } + + c.notifyProducerWithoutListenerQueueControlEvent(job.Queue, &controlEventPayload{ + Action: controlActionCancel, + JobID: job.ID, + Queue: job.Queue, + }) + + return job, nil } // JobCancelTx cancels the job with the given ID within the specified @@ -1535,21 +1550,16 @@ func (c *Client[TTx]) Insert(ctx context.Context, args JobArgs, opts *InsertOpts return nil, errNoDriverDBPool } - tx, err := c.driver.GetExecutor().Begin(ctx) + res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) (*insertManySharedResult, error) { + return c.validateParamsAndInsertMany(ctx, execTx, []InsertManyParams{{Args: args, InsertOpts: opts}}) + }) if err != nil { return nil, err } - defer tx.Rollback(ctx) - inserted, err := c.insert(ctx, tx, args, opts) - if err != nil { - return nil, err - } + c.notifyProducerWithoutListenerJobFetch(res.QueuesDeduped) - if err := tx.Commit(ctx); err != nil { - return nil, err - } - return inserted, nil + return res.InsertResults[0], nil } // InsertTx inserts a new job with the provided args on the given transaction. @@ -1570,17 +1580,11 @@ func (c *Client[TTx]) Insert(ctx context.Context, args JobArgs, opts *InsertOpts // transactions, the job will not be worked until the transaction has committed, // and if the transaction rolls back, so too is the inserted job. func (c *Client[TTx]) InsertTx(ctx context.Context, tx TTx, args JobArgs, opts *InsertOpts) (*rivertype.JobInsertResult, error) { - return c.insert(ctx, c.driver.UnwrapExecutor(tx), args, opts) -} - -func (c *Client[TTx]) insert(ctx context.Context, tx riverdriver.ExecutorTx, args JobArgs, opts *InsertOpts) (*rivertype.JobInsertResult, error) { - params := []InsertManyParams{{Args: args, InsertOpts: opts}} - results, err := c.validateParamsAndInsertMany(ctx, tx, params) + res, err := c.validateParamsAndInsertMany(ctx, c.driver.UnwrapExecutor(tx), []InsertManyParams{{Args: args, InsertOpts: opts}}) if err != nil { return nil, err } - - return results[0], nil + return res.InsertResults[0], nil } // InsertManyParams encapsulates a single job combined with insert options for @@ -1612,21 +1616,16 @@ func (c *Client[TTx]) InsertMany(ctx context.Context, params []InsertManyParams) return nil, errNoDriverDBPool } - tx, err := c.driver.GetExecutor().Begin(ctx) + res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) (*insertManySharedResult, error) { + return c.validateParamsAndInsertMany(ctx, execTx, params) + }) if err != nil { return nil, err } - defer tx.Rollback(ctx) - inserted, err := c.validateParamsAndInsertMany(ctx, tx, params) - if err != nil { - return nil, err - } + c.notifyProducerWithoutListenerJobFetch(res.QueuesDeduped) - if err := tx.Commit(ctx); err != nil { - return nil, err - } - return inserted, nil + return res.InsertResults, nil } // InsertManyTx inserts many jobs at once. Each job is inserted as an @@ -1648,28 +1647,31 @@ func (c *Client[TTx]) InsertMany(ctx context.Context, params []InsertManyParams) // changes. An inserted job isn't visible to be worked until the transaction // commits, and if the transaction rolls back, so too is the inserted job. func (c *Client[TTx]) InsertManyTx(ctx context.Context, tx TTx, params []InsertManyParams) ([]*rivertype.JobInsertResult, error) { - exec := c.driver.UnwrapExecutor(tx) - return c.validateParamsAndInsertMany(ctx, exec, params) + res, err := c.validateParamsAndInsertMany(ctx, c.driver.UnwrapExecutor(tx), params) + if err != nil { + return nil, err + } + return res.InsertResults, nil } // validateParamsAndInsertMany is a helper method that wraps the insertMany // method to provide param validation and conversion prior to calling the actual // insertMany method. This allows insertMany to be reused by the // PeriodicJobEnqueuer which cannot reference top-level river package types. -func (c *Client[TTx]) validateParamsAndInsertMany(ctx context.Context, tx riverdriver.ExecutorTx, params []InsertManyParams) ([]*rivertype.JobInsertResult, error) { +func (c *Client[TTx]) validateParamsAndInsertMany(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) (*insertManySharedResult, error) { insertParams, err := c.insertManyParams(params) if err != nil { return nil, err } - return c.insertMany(ctx, tx, insertParams) + return c.insertMany(ctx, execTx, insertParams) } // insertMany is a shared code path for InsertMany and InsertManyTx, also used // by the PeriodicJobEnqueuer. -func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) { - return c.insertManyShared(ctx, tx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { - results, err := c.pilot.JobInsertMany(ctx, tx, &riverdriver.JobInsertFastManyParams{ +func (c *Client[TTx]) insertMany(ctx context.Context, execTx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) (*insertManySharedResult, error) { + return c.insertManyShared(ctx, execTx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { + results, err := c.pilot.JobInsertMany(ctx, execTx, &riverdriver.JobInsertFastManyParams{ Jobs: insertParams, Schema: c.config.Schema, }) @@ -1685,6 +1687,11 @@ func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, }) } +type insertManySharedResult struct { + InsertResults []*rivertype.JobInsertResult + QueuesDeduped []string +} + // The shared code path for all Insert and InsertMany methods. It takes a // function that executes the actual insert operation and allows for different // implementations of the insert query to be passed in, each mapping their @@ -1694,7 +1701,9 @@ func (c *Client[TTx]) insertManyShared( tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams, execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error), -) ([]*rivertype.JobInsertResult, error) { +) (*insertManySharedResult, error) { + var queuesDeduped []string + doInner := func(ctx context.Context) ([]*rivertype.JobInsertResult, error) { for _, params := range insertParams { for _, hook := range append( @@ -1710,9 +1719,10 @@ func (c *Client[TTx]) insertManyShared( finalInsertParams := sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams { return (*riverdriver.JobInsertFastParams)(params) }) - results, err := execute(ctx, finalInsertParams) + + insertResults, err := execute(ctx, finalInsertParams) if err != nil { - return results, err + return insertResults, err } queues := make([]string, 0, 10) @@ -1721,10 +1731,14 @@ func (c *Client[TTx]) insertManyShared( queues = append(queues, params.Queue) } } - if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil { + + queuesDeduped = sliceutil.Uniq(queues) + + if err = c.maybeNotifyInsertForQueues(ctx, tx, queuesDeduped); err != nil { return nil, err } - return results, nil + + return insertResults, nil } jobInsertMiddleware := c.middlewareLookupGlobal.ByMiddlewareKind(middlewarelookup.MiddlewareKindJobInsert) @@ -1740,7 +1754,15 @@ func (c *Client[TTx]) insertManyShared( } } - return doInner(ctx) + insertResults, err := doInner(ctx) + if err != nil { + return nil, err + } + + return &insertManySharedResult{ + InsertResults: insertResults, + QueuesDeduped: queuesDeduped, + }, nil } // Validates input parameters for a batch insert operation and generates a set @@ -1767,6 +1789,25 @@ func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*rivertype. return insertParams, nil } +// Notifies an internal producer of new jobs being queued for work. Only +// invoked if the client's driver doesn't support a listener. If a listener is +// supported, job notifications go out via listen/notify instead. +// +// Should only ever be invoked *outside* a transaction. If invoked within a +// transaction, the producer wouldn't yet be able to access the new jobs that +// triggered the notification because they're not committed yet. +func (c *Client[TTx]) notifyProducerWithoutListenerJobFetch(queuesDeduped []string) { + if c.driver.SupportsListener() || len(c.producersByQueueName) < 1 { + return + } + + for _, queue := range queuesDeduped { + if producer, ok := c.producersByQueueName[queue]; ok { + producer.TriggerJobFetch() + } + } +} + // InsertManyFast inserts many jobs at once using Postgres' `COPY FROM` mechanism, // making the operation quite fast and memory efficient. Each job is inserted as // an InsertManyParams tuple, which takes job args along with an optional set of @@ -1791,20 +1832,16 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar } // Wrap in a transaction in case we need to notify about inserts. - tx, err := c.driver.GetExecutor().Begin(ctx) + res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) (*insertManySharedResult, error) { + return c.insertManyFast(ctx, execTx, params) + }) if err != nil { return 0, err } - defer tx.Rollback(ctx) - inserted, err := c.insertManyFast(ctx, tx, params) - if err != nil { - return 0, err - } - if err := tx.Commit(ctx); err != nil { - return 0, err - } - return inserted, nil + c.notifyProducerWithoutListenerJobFetch(res.QueuesDeduped) + + return len(res.InsertResults), nil } // InsertManyTx inserts many jobs at once using Postgres' `COPY FROM` mechanism, @@ -1831,17 +1868,20 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar // a unique constraint is violated, the operation will fail and no jobs will be // inserted. func (c *Client[TTx]) InsertManyFastTx(ctx context.Context, tx TTx, params []InsertManyParams) (int, error) { - exec := c.driver.UnwrapExecutor(tx) - return c.insertManyFast(ctx, exec, params) + res, err := c.insertManyFast(ctx, c.driver.UnwrapExecutor(tx), params) + if err != nil { + return 0, err + } + return len(res.InsertResults), nil } -func (c *Client[TTx]) insertManyFast(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) (int, error) { +func (c *Client[TTx]) insertManyFast(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) (*insertManySharedResult, error) { insertParams, err := c.insertManyParams(params) if err != nil { - return 0, err + return nil, err } - results, err := c.insertManyShared(ctx, execTx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { + return c.insertManyShared(ctx, execTx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) { count, err := execTx.JobInsertFastManyNoReturning(ctx, &riverdriver.JobInsertFastManyParams{ Jobs: insertParams, Schema: c.config.Schema, @@ -1851,34 +1891,25 @@ func (c *Client[TTx]) insertManyFast(ctx context.Context, execTx riverdriver.Exe } return make([]*rivertype.JobInsertResult, count), nil }) - if err != nil { - return 0, err - } - - return len(results), nil } // Notify the given queues that new jobs are available. The queues list will be // deduplicated and each will be checked to see if it is due for an insert // notification from this client. -func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error { - if len(queues) < 1 { +func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdriver.ExecutorTx, queuesDeduped []string) error { + if len(queuesDeduped) < 1 { return nil } - queueMap := make(map[string]struct{}) - queuesDeduped := make([]string, 0, len(queues)) - payloads := make([]string, 0, len(queues)) - - for _, queue := range queues { - if _, ok := queueMap[queue]; ok { - continue - } + var ( + payloads = make([]string, 0, len(queuesDeduped)) + queuesTriggered = make([]string, 0, len(queuesDeduped)) + ) - queueMap[queue] = struct{}{} + for _, queue := range queuesDeduped { if c.insertNotifyLimiter.ShouldTrigger(queue) { payloads = append(payloads, fmt.Sprintf("{\"queue\": %q}", queue)) - queuesDeduped = append(queuesDeduped, queue) + queuesTriggered = append(queuesTriggered, queue) } } @@ -1896,7 +1927,7 @@ func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdr c.baseService.Logger.ErrorContext( ctx, c.baseService.Name+": Failed to send job insert notification", - slog.String("queues", strings.Join(queuesDeduped, ",")), + slog.String("queues", strings.Join(queuesTriggered, ",")), slog.String("err", err.Error()), ) return err @@ -1907,7 +1938,7 @@ func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdr } // emit a notification about a queue being paused or resumed. -func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriver.ExecutorTx, action controlAction, queue string, opts *QueuePauseOpts) error { +func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriver.ExecutorTx, action controlAction, queue string, opts *QueuePauseOpts) (*controlEventPayload, error) { c.baseService.Logger.DebugContext(ctx, c.baseService.Name+": Notifying about queue state change", slog.String("action", string(action)), @@ -1915,9 +1946,11 @@ func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriv slog.String("opts", fmt.Sprintf("%+v", opts)), ) - payload, err := json.Marshal(controlEventPayload{Action: action, Queue: queue}) + controlEvent := &controlEventPayload{Action: action, Queue: queue} + + payload, err := json.Marshal(controlEvent) if err != nil { - return err + return nil, err } if c.driver.SupportsListenNotify() { @@ -1932,11 +1965,11 @@ func (c *Client[TTx]) notifyQueuePauseOrResume(ctx context.Context, tx riverdriv c.baseService.Name+": Failed to send queue state change notification", slog.String("err", err.Error()), ) - return err + return nil, err } } - return nil + return controlEvent, nil } // Validates job args prior to insertion. Currently, verifies that a worker to @@ -2221,11 +2254,18 @@ func (c *Client[TTx]) QueuePause(ctx context.Context, name string, opts *QueuePa return err } - if err := c.notifyQueuePauseOrResume(ctx, tx, controlActionPause, name, opts); err != nil { + controlEvent, err := c.notifyQueuePauseOrResume(ctx, tx, controlActionPause, name, opts) + if err != nil { + return err + } + + if err = tx.Commit(ctx); err != nil { return err } - return tx.Commit(ctx) + c.notifyProducerWithoutListenerQueueControlEvent(name, controlEvent) + + return nil } // QueuePauseTx pauses the queue with the given name. When a queue is paused, @@ -2249,7 +2289,7 @@ func (c *Client[TTx]) QueuePauseTx(ctx context.Context, tx TTx, name string, opt return err } - if err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionPause, name, opts); err != nil { + if _, err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionPause, name, opts); err != nil { return err } @@ -2282,11 +2322,18 @@ func (c *Client[TTx]) QueueResume(ctx context.Context, name string, opts *QueueP return err } - if err := c.notifyQueuePauseOrResume(ctx, tx, controlActionResume, name, opts); err != nil { + controlEvent, err := c.notifyQueuePauseOrResume(ctx, tx, controlActionResume, name, opts) + if err != nil { return err } - return tx.Commit(ctx) + if err = tx.Commit(ctx); err != nil { + return err + } + + c.notifyProducerWithoutListenerQueueControlEvent(name, controlEvent) + + return nil } // QueueResume resumes the queue with the given name. If the queue was @@ -2311,7 +2358,7 @@ func (c *Client[TTx]) QueueResumeTx(ctx context.Context, tx TTx, name string, op return err } - if err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionResume, name, opts); err != nil { + if _, err := c.notifyQueuePauseOrResume(ctx, executorTx, controlActionResume, name, opts); err != nil { return err } @@ -2334,7 +2381,7 @@ func (c *Client[TTx]) QueueUpdate(ctx context.Context, name string, params *Queu } defer tx.Rollback(ctx) - queue, err := c.queueUpdate(ctx, tx, name, params) + queue, controlEvent, err := c.queueUpdate(ctx, tx, name, params) if err != nil { return nil, err } @@ -2343,18 +2390,39 @@ func (c *Client[TTx]) QueueUpdate(ctx context.Context, name string, params *Queu return nil, err } + c.notifyProducerWithoutListenerQueueControlEvent(name, controlEvent) + return queue, nil } // QueueUpdateTx updates a queue's settings in the database. These settings // override the settings in the client (if applied). func (c *Client[TTx]) QueueUpdateTx(ctx context.Context, tx TTx, name string, params *QueueUpdateParams) (*rivertype.Queue, error) { - executorTx := c.driver.UnwrapExecutor(tx) + queue, _, err := c.queueUpdate(ctx, c.driver.UnwrapExecutor(tx), name, params) + if err != nil { + return nil, err + } + return queue, nil +} - return c.queueUpdate(ctx, executorTx, name, params) +// Notifies an internal producer of a queue control event like pause/resume. +// Only invoked if the client's driver doesn't support a listener. If a listener +// is supported, control events go out via listen/notify instead. +// +// Should only ever be invoked *outside* a transaction. If invoked within a +// transaction, the producer wouldn't yet be able to access the state that +// triggered the notification because it's not committed yet. +func (c *Client[TTx]) notifyProducerWithoutListenerQueueControlEvent(queue string, controlEvent *controlEventPayload) { + if c.driver.SupportsListener() || len(c.producersByQueueName) < 1 { + return + } + + if producer, ok := c.producersByQueueName[queue]; ok { + producer.TriggerQueueControlEvent(controlEvent) + } } -func (c *Client[TTx]) queueUpdate(ctx context.Context, executorTx riverdriver.ExecutorTx, name string, params *QueueUpdateParams) (*rivertype.Queue, error) { +func (c *Client[TTx]) queueUpdate(ctx context.Context, executorTx riverdriver.ExecutorTx, name string, params *QueueUpdateParams) (*rivertype.Queue, *controlEventPayload, error) { updateMetadata := len(params.Metadata) > 0 queue, err := executorTx.QueueUpdate(ctx, &riverdriver.QueueUpdateParams{ @@ -2364,31 +2432,35 @@ func (c *Client[TTx]) queueUpdate(ctx context.Context, executorTx riverdriver.Ex Schema: c.config.Schema, }) if err != nil { - return nil, err + return nil, nil, err } - if updateMetadata { - payload, err := json.Marshal(controlEventPayload{ - Action: controlActionMetadataChanged, - Metadata: params.Metadata, - Queue: queue.Name, - }) - if err != nil { - return nil, err - } + if !updateMetadata { + return queue, nil, err + } - if c.driver.SupportsListenNotify() { - if err := executorTx.NotifyMany(ctx, &riverdriver.NotifyManyParams{ - Payload: []string{string(payload)}, - Schema: c.config.Schema, - Topic: string(notifier.NotificationTopicControl), - }); err != nil { - return nil, err - } + controlEvent := &controlEventPayload{ + Action: controlActionMetadataChanged, + Metadata: params.Metadata, + Queue: queue.Name, + } + + payload, err := json.Marshal(controlEvent) + if err != nil { + return nil, nil, err + } + + if c.driver.SupportsListenNotify() { + if err := executorTx.NotifyMany(ctx, &riverdriver.NotifyManyParams{ + Payload: []string{string(payload)}, + Schema: c.config.Schema, + Topic: string(notifier.NotificationTopicControl), + }); err != nil { + return nil, nil, err } } - return queue, nil + return queue, controlEvent, nil } // QueueBundle is a bundle for adding additional queues. It's made accessible diff --git a/client_test.go b/client_test.go index 783ed24b..c18267ba 100644 --- a/client_test.go +++ b/client_test.go @@ -36,6 +36,7 @@ import ( "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverdatabasesql" "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/riverdriver/riversqlite" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/riversharedtest" "github.com/riverqueue/river/rivershared/startstoptest" @@ -134,7 +135,6 @@ func newTestConfig(t *testing.T, schema string) *Config { }, TestOnly: true, // disables staggered start in maintenance services Workers: workers, - queuePollInterval: 50 * time.Millisecond, schedulerInterval: riverinternaltest.SchedulerShortInterval, } } @@ -608,6 +608,69 @@ func Test_Client(t *testing.T) { require.NoError(t, err) }) + t.Run("CancelProducerControlEventSent", func(t *testing.T) { + t.Parallel() + + var ( + driver = riversqlite.New(nil) + schema = riverdbtest.TestSchema(ctx, t, driver, &riverdbtest.TestSchemaOpts{ + ProcurePool: func(ctx context.Context, schema string) (any, string) { + return riversharedtest.DBPoolSQLite(ctx, t, schema), "" // could also be `main` instead of empty string + }, + }) + config = newTestConfig(t, schema) + ) + + client, err := NewClient(driver, config) + require.NoError(t, err) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + type JobArgs struct { + JobArgsReflectKind[JobArgs] + } + + AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + return nil + })) + + startClient(ctx, t, client) + + insertRes, err := client.Insert(ctx, &JobArgs{}, nil) + require.NoError(t, err) + + _, err = client.JobCancel(ctx, insertRes.Job.ID) + require.NoError(t, err) + + controlEvent := client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.WaitOrTimeout() + require.NotNil(t, controlEvent) + require.Equal(t, controlActionCancel, controlEvent.Action) + }) + + t.Run("CancelProducerControlEventNotSent", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + type JobArgs struct { + JobArgsReflectKind[JobArgs] + } + + AddWorker(client.config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + return nil + })) + + startClient(ctx, t, client) + + insertRes, err := client.Insert(ctx, &JobArgs{}, nil) + require.NoError(t, err) + + _, err = client.JobCancel(ctx, insertRes.Job.ID) + require.NoError(t, err) + + client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.RequireEmpty() + }) + t.Run("AlternateSchema", func(t *testing.T) { t.Parallel() @@ -1161,6 +1224,55 @@ func Test_Client(t *testing.T) { } }) + t.Run("QueuePauseAndResumeProducerControlEventSent", func(t *testing.T) { + t.Parallel() + + var ( + driver = riversqlite.New(nil) + schema = riverdbtest.TestSchema(ctx, t, driver, &riverdbtest.TestSchemaOpts{ + ProcurePool: func(ctx context.Context, schema string) (any, string) { + return riversharedtest.DBPoolSQLite(ctx, t, schema), "" // could also be `main` instead of empty string + }, + }) + config = newTestConfig(t, schema) + ) + + client, err := NewClient(driver, config) + require.NoError(t, err) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + require.NoError(t, client.QueuePause(ctx, QueueDefault, nil)) + + controlEvent := client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.WaitOrTimeout() + require.NotNil(t, controlEvent) + require.Equal(t, controlActionPause, controlEvent.Action) + + require.NoError(t, client.QueueResume(ctx, QueueDefault, nil)) + + controlEvent = client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.WaitOrTimeout() + require.NotNil(t, controlEvent) + require.Equal(t, controlActionResume, controlEvent.Action) + }) + + t.Run("QueuePauseAndResumeProducerControlEventNotSent", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + require.NoError(t, client.QueuePause(ctx, QueueDefault, nil)) + + client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.RequireEmpty() + + require.NoError(t, client.QueueResume(ctx, QueueDefault, nil)) + + client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.RequireEmpty() + }) + t.Run("PollOnlyDriver", func(t *testing.T) { t.Parallel() @@ -1172,7 +1284,6 @@ func Test_Client(t *testing.T) { client, err := NewClient(riverdatabasesql.New(stdPool), config) require.NoError(t, err) - client.testSignals.Init(t) // Notifier should not have been initialized at all. @@ -2018,6 +2129,46 @@ func Test_Client_Insert(t *testing.T) { require.Equal(t, []string{}, jobRow.Tags) }) + t.Run("ProducerFetchLimiterCalled", func(t *testing.T) { + t.Parallel() + + var ( + driver = riversqlite.New(nil) + schema = riverdbtest.TestSchema(ctx, t, driver, &riverdbtest.TestSchemaOpts{ + ProcurePool: func(ctx context.Context, schema string) (any, string) { + return riversharedtest.DBPoolSQLite(ctx, t, schema), "" // could also be `main` instead of empty string + }, + }) + config = newTestConfig(t, schema) + ) + + client, err := NewClient(driver, config) + require.NoError(t, err) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + _, err = client.Insert(ctx, &noOpArgs{}, nil) + require.NoError(t, err) + + client.producersByQueueName[QueueDefault].testSignals.JobFetchTriggered.WaitOrTimeout() + }) + + // Not called for drivers that support a listener. + t.Run("ProducerFetchLimiterNotCalled", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + _, err := client.Insert(ctx, &noOpArgs{}, nil) + require.NoError(t, err) + + client.producersByQueueName[QueueDefault].testSignals.JobFetchTriggered.RequireEmpty() + }) + t.Run("WithInsertOpts", func(t *testing.T) { t.Parallel() @@ -4948,6 +5099,52 @@ func Test_Client_QueueUpdate(t *testing.T) { case <-time.After(100 * time.Millisecond): } }) + + t.Run("ProducerControlEventSent", func(t *testing.T) { + t.Parallel() + + var ( + driver = riversqlite.New(nil) + schema = riverdbtest.TestSchema(ctx, t, driver, &riverdbtest.TestSchemaOpts{ + ProcurePool: func(ctx context.Context, schema string) (any, string) { + return riversharedtest.DBPoolSQLite(ctx, t, schema), "" // could also be `main` instead of empty string + }, + }) + config = newTestConfig(t, schema) + ) + + client, err := NewClient(driver, config) + require.NoError(t, err) + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + _, err = client.QueueUpdate(ctx, QueueDefault, &QueueUpdateParams{ + Metadata: []byte(`{"foo":"baz"}`), + }) + require.NoError(t, err) + + controlEvent := client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.WaitOrTimeout() + require.NotNil(t, controlEvent) + require.Equal(t, controlActionMetadataChanged, controlEvent.Action) + }) + + t.Run("ProducerControlEventNotSent", func(t *testing.T) { + t.Parallel() + + client, _ := setup(t) + + client.producersByQueueName[QueueDefault].testSignals.Init(t) + + startClient(ctx, t, client) + + _, err := client.QueueUpdate(ctx, QueueDefault, &QueueUpdateParams{ + Metadata: []byte(`{"foo":"baz"}`), + }) + require.NoError(t, err) + + client.producersByQueueName[QueueDefault].testSignals.QueueControlEventTriggered.RequireEmpty() + }) } func Test_Client_QueueUpdateTx(t *testing.T) { diff --git a/example_sqlite_test.go b/example_sqlite_test.go index 976dc7a6..694e95a2 100644 --- a/example_sqlite_test.go +++ b/example_sqlite_test.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "log/slog" - "time" "github.com/riverqueue/river" "github.com/riverqueue/river/riverdriver" @@ -36,9 +35,7 @@ func Example_sqlite() { river.AddWorker(workers, &SortWorker{}) riverClient, err := river.NewClient(driver, &river.Config{ - FetchCooldown: 20 * time.Millisecond, - FetchPollInterval: 50 * time.Millisecond, // this driver is poll only, so speed up poll interval so the test runs fater - Logger: slog.New(&slogutil.SlogMessageOnlyHandler{Level: slog.LevelWarn}), + Logger: slog.New(&slogutil.SlogMessageOnlyHandler{Level: slog.LevelWarn}), Queues: map[string]river.QueueConfig{ river.QueueDefault: {MaxWorkers: 100}, }, diff --git a/internal/maintenance/job_scheduler.go b/internal/maintenance/job_scheduler.go index 9021ae5c..ac1eb9dc 100644 --- a/internal/maintenance/job_scheduler.go +++ b/internal/maintenance/job_scheduler.go @@ -14,6 +14,7 @@ import ( "github.com/riverqueue/river/rivershared/testsignal" "github.com/riverqueue/river/rivershared/util/randutil" "github.com/riverqueue/river/rivershared/util/serviceutil" + "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivershared/util/testutil" "github.com/riverqueue/river/rivershared/util/timeutil" "github.com/riverqueue/river/rivertype" @@ -35,11 +36,11 @@ func (ts *JobSchedulerTestSignals) Init(tb testutil.TestingTB) { ts.ScheduledBatch.Init(tb) } -type InsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) +type InsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error // NotifyInsert is a function to call to emit notifications for queues where // jobs were scheduled. -type NotifyInsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error +type NotifyInsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, queuesDeduped []string) error type JobSchedulerConfig struct { // Interval is the amount of time between periodic checks for jobs to @@ -186,7 +187,7 @@ func (s *JobScheduler) runOnce(ctx context.Context) (*schedulerRunOnceResult, er } if len(queues) > 0 { - if err := s.config.NotifyInsert(ctx, tx, queues); err != nil { + if err := s.config.NotifyInsert(ctx, tx, sliceutil.Uniq(queues)); err != nil { return 0, fmt.Errorf("error notifying insert: %w", err) } s.TestSignals.NotifiedQueues.Signal(queues) diff --git a/internal/maintenance/job_scheduler_test.go b/internal/maintenance/job_scheduler_test.go index 0c934778..5944d3d4 100644 --- a/internal/maintenance/job_scheduler_test.go +++ b/internal/maintenance/job_scheduler_test.go @@ -320,8 +320,8 @@ func TestJobScheduler(t *testing.T) { scheduler, _ := setup(t, &testOpts{exec: exec, schema: schema}) scheduler.config.Interval = time.Minute // should only trigger once for the initial run - scheduler.config.NotifyInsert = func(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error { - notifyCh <- queues + scheduler.config.NotifyInsert = func(ctx context.Context, tx riverdriver.ExecutorTx, queuesDeduped []string) error { + notifyCh <- queuesDeduped return nil } now := time.Now().UTC() @@ -362,10 +362,10 @@ func TestJobScheduler(t *testing.T) { require.NoError(t, scheduler.Start(ctx)) scheduler.TestSignals.ScheduledBatch.WaitOrTimeout() - expectedQueues := []string{"queue1", "queue2", "queue2", "queue3", "queue4"} + expectedQueues := []string{"queue1", "queue2", "queue3", "queue4"} - notifiedQueues := riversharedtest.WaitOrTimeout(t, notifyCh) - sort.Strings(notifiedQueues) - require.Equal(t, expectedQueues, notifiedQueues) + notifiedQueuesDeduped := riversharedtest.WaitOrTimeout(t, notifyCh) + sort.Strings(notifiedQueuesDeduped) + require.Equal(t, expectedQueues, notifiedQueuesDeduped) }) } diff --git a/internal/maintenance/periodic_job_enqueuer.go b/internal/maintenance/periodic_job_enqueuer.go index 90e5d6aa..fe11a834 100644 --- a/internal/maintenance/periodic_job_enqueuer.go +++ b/internal/maintenance/periodic_job_enqueuer.go @@ -342,8 +342,7 @@ func (s *PeriodicJobEnqueuer) insertBatch(ctx context.Context, insertParamsMany defer tx.Rollback(ctx) if len(insertParamsMany) > 0 { - _, err := s.Config.Insert(ctx, tx, insertParamsMany) - if err != nil { + if err := s.Config.Insert(ctx, tx, insertParamsMany); err != nil { s.Logger.ErrorContext(ctx, s.Name+": Error inserting periodic jobs", "error", err.Error(), "num_jobs", len(insertParamsMany)) return diff --git a/internal/maintenance/periodic_job_enqueuer_test.go b/internal/maintenance/periodic_job_enqueuer_test.go index 56bfbbd6..22a84707 100644 --- a/internal/maintenance/periodic_job_enqueuer_test.go +++ b/internal/maintenance/periodic_job_enqueuer_test.go @@ -79,22 +79,15 @@ func TestPeriodicJobEnqueuer(t *testing.T) { // A simplified version of `Client.insertMany` that only inserts jobs directly // via the driver instead of using the pilot. - makeInsertFunc := func(schema string) func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) { - return func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) { - results, err := tx.JobInsertFastMany(ctx, &riverdriver.JobInsertFastManyParams{ + makeInsertFunc := func(schema string) func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error { + return func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error { + _, err := tx.JobInsertFastMany(ctx, &riverdriver.JobInsertFastManyParams{ Jobs: sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams { return (*riverdriver.JobInsertFastParams)(params) }), Schema: schema, }) - if err != nil { - return nil, err - } - return sliceutil.Map(results, - func(result *riverdriver.JobInsertFastResult) *rivertype.JobInsertResult { - return (*rivertype.JobInsertResult)(result) - }, - ), nil + return err } } diff --git a/producer.go b/producer.go index f2417769..86fce3ad 100644 --- a/producer.go +++ b/producer.go @@ -41,21 +41,25 @@ const ( // Test-only properties. type producerTestSignals struct { - DeletedExpiredQueueRecords testsignal.TestSignal[struct{}] // notifies when the producer deletes expired queue records - MetadataChanged testsignal.TestSignal[struct{}] // notifies when the producer detects a metadata change - Paused testsignal.TestSignal[struct{}] // notifies when the producer is paused - PolledQueueConfig testsignal.TestSignal[struct{}] // notifies when the producer polls for queue settings - ReportedProducerStatus testsignal.TestSignal[struct{}] // notifies when the producer reports its own status - ReportedQueueStatus testsignal.TestSignal[struct{}] // notifies when the producer reports queue status - Resumed testsignal.TestSignal[struct{}] // notifies when the producer is resumed - StartedExecutors testsignal.TestSignal[struct{}] // notifies when runOnce finishes a pass + DeletedExpiredQueueRecords testsignal.TestSignal[struct{}] // notifies when the producer deletes expired queue records + JobFetchTriggered testsignal.TestSignal[struct{}] // notifies when the producer's fetch limiter is triggered via triggerJobFetch + MetadataChanged testsignal.TestSignal[struct{}] // notifies when the producer detects a metadata change + Paused testsignal.TestSignal[struct{}] // notifies when the producer is paused + PolledQueueConfig testsignal.TestSignal[struct{}] // notifies when the producer polls for queue settings + QueueControlEventTriggered testsignal.TestSignal[*controlEventPayload] // notifies when a queue control event is triggered via triggerQueueControlEvent + ReportedProducerStatus testsignal.TestSignal[struct{}] // notifies when the producer reports its own status + ReportedQueueStatus testsignal.TestSignal[struct{}] // notifies when the producer reports queue status + Resumed testsignal.TestSignal[struct{}] // notifies when the producer is resumed + StartedExecutors testsignal.TestSignal[struct{}] // notifies when runOnce finishes a pass } func (ts *producerTestSignals) Init(tb testutil.TestingTB) { ts.DeletedExpiredQueueRecords.Init(tb) + ts.JobFetchTriggered.Init(tb) ts.MetadataChanged.Init(tb) ts.Paused.Init(tb) ts.PolledQueueConfig.Init(tb) + ts.QueueControlEventTriggered.Init(tb) ts.ReportedQueueStatus.Init(tb) ts.ReportedProducerStatus.Init(tb) ts.Resumed.Init(tb) @@ -182,6 +186,7 @@ type producer struct { id atomic.Int64 // atomic because it's written at startup and read during shutdown exec riverdriver.Executor errorHandler jobexecutor.ErrorHandler + fetchLimiter *chanutil.DebouncedChan state riverpilot.ProducerState pilot riverpilot.Pilot workers *Workers @@ -324,7 +329,7 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { p.id.Store(id) // TODO: fetcher should have some jitter in it to avoid stampeding issues. - fetchLimiter := chanutil.NewDebouncedChan(fetchCtx, p.config.FetchCooldown, true) + p.fetchLimiter = chanutil.NewDebouncedChan(fetchCtx, p.config.FetchCooldown, true) var ( controlSub *notifier.Subscription @@ -343,7 +348,7 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { return } p.Logger.DebugContext(workCtx, p.Name+": Received insert notification", slog.String("queue", decoded.Queue)) - fetchLimiter.Call() + p.fetchLimiter.Call() } insertSub, err = p.config.Notifier.Listen(fetchCtx, notifier.NotificationTopicInsert, handleInsertNotification) if err != nil { @@ -395,7 +400,7 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { go p.pollForSettingChanges(subroutineCtx, &subroutineWG, initiallyPaused, initialMetadata) } - p.fetchAndRunLoop(fetchCtx, workCtx, fetchLimiter) + p.fetchAndRunLoop(fetchCtx, workCtx) p.Logger.Debug(p.Name+": Entering shutdown loop", slog.String("queue", p.config.Queue), slog.Int64("id", p.id.Load())) p.executorShutdownLoop() @@ -410,6 +415,29 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { return nil } +// TriggerJobFetch manually triggers the producer to perform a job fetch +// (although it's debounced, so it may not happen immediately if a fetch was +// performed very recently). This is used by clients using drivers that don't +// support listeners to wake a producer immediately after a job insert was known +// to be performed so the producer doesn't have to wait on polling. +func (p *producer) TriggerJobFetch() { + if p.fetchLimiter != nil { + p.fetchLimiter.Call() + } + p.testSignals.JobFetchTriggered.Signal(struct{}{}) +} + +// TriggerQueueControlEvent manually injects a queue control event into the +// producer's queue control channel as if it'd been received through +// listen/notify. This is used by clients using drivers that don't support +// listeners to wake a producer immediately after a queue control event was +// known to be performed so the producer doesn't have to wait on polling. +func (p *producer) TriggerQueueControlEvent(controlEvent *controlEventPayload) { + p.queueControlCh <- controlEvent + p.testSignals.QueueControlEventTriggered.Signal(controlEvent) + +} + type controlAction string const ( @@ -475,10 +503,10 @@ func (p *producer) handleControlNotification(workCtx context.Context) func(notif } } -func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimiter *chanutil.DebouncedChan) { +func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context) { // Prime the fetchLimiter so we can make an initial fetch without waiting for // an insert notification or a fetch poll. - fetchLimiter.Call() + p.fetchLimiter.Call() fetchPollTimer := time.NewTimer(p.config.FetchPollInterval) go func() { @@ -491,7 +519,7 @@ func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimit } return case <-fetchPollTimer.C: - fetchLimiter.Call() + p.fetchLimiter.Call() fetchPollTimer.Reset(p.config.FetchPollInterval) } } @@ -533,7 +561,7 @@ func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimit } p.paused = false p.Logger.DebugContext(workCtx, p.Name+": Resumed", slog.String("queue", p.config.Queue), slog.String("queue_in_message", msg.Queue)) - fetchLimiter.Call() // try another fetch because more jobs may be available to run which were gated behind the paused queue + p.fetchLimiter.Call() // try another fetch because more jobs may be available to run which were gated behind the paused queue p.testSignals.Resumed.Signal(struct{}{}) if p.config.QueueEventCallback != nil { p.config.QueueEventCallback(&Event{Kind: EventKindQueueResumed, Queue: &rivertype.Queue{Name: p.config.Queue}}) @@ -543,7 +571,7 @@ func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimit } case jobID := <-p.cancelCh: p.maybeCancelJob(jobID) - case <-fetchLimiter.C(): + case <-p.fetchLimiter.C(): p.innerFetchLoop(workCtx, fetchResultCh) // Ensure we can't start another fetch when fetchCtx is done, even if // the fetchLimiter is also ready to fire: @@ -560,7 +588,7 @@ func (p *producer) fetchAndRunLoop(fetchCtx, workCtx context.Context, fetchLimit // more aggressive triggering the fetch limiter now that we have a slot // available. p.fetchWhenSlotsAreAvailable = false - fetchLimiter.Call() + p.fetchLimiter.Call() } } } diff --git a/rivershared/riverpilot/pilot.go b/rivershared/riverpilot/pilot.go index c91e32e8..a6cee157 100644 --- a/rivershared/riverpilot/pilot.go +++ b/rivershared/riverpilot/pilot.go @@ -25,11 +25,11 @@ type Pilot interface { JobInsertMany( ctx context.Context, - tx riverdriver.ExecutorTx, + execTx riverdriver.ExecutorTx, params *riverdriver.JobInsertFastManyParams, ) ([]*riverdriver.JobInsertFastResult, error) - JobSetStateIfRunningMany(ctx context.Context, tx riverdriver.ExecutorTx, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) + JobSetStateIfRunningMany(ctx context.Context, execTx riverdriver.ExecutorTx, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) PilotInit(archetype *baseservice.Archetype) diff --git a/rivershared/riverpilot/standard.go b/rivershared/riverpilot/standard.go index 30fe773e..ea12b207 100644 --- a/rivershared/riverpilot/standard.go +++ b/rivershared/riverpilot/standard.go @@ -22,14 +22,14 @@ func (p *StandardPilot) JobGetAvailable(ctx context.Context, exec riverdriver.Ex func (p *StandardPilot) JobInsertMany( ctx context.Context, - tx riverdriver.ExecutorTx, + execTx riverdriver.ExecutorTx, params *riverdriver.JobInsertFastManyParams, ) ([]*riverdriver.JobInsertFastResult, error) { - return tx.JobInsertFastMany(ctx, params) + return execTx.JobInsertFastMany(ctx, params) } -func (p *StandardPilot) JobSetStateIfRunningMany(ctx context.Context, tx riverdriver.ExecutorTx, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) { - return tx.JobSetStateIfRunningMany(ctx, params) +func (p *StandardPilot) JobSetStateIfRunningMany(ctx context.Context, execTx riverdriver.ExecutorTx, params *riverdriver.JobSetStateIfRunningManyParams) ([]*rivertype.JobRow, error) { + return execTx.JobSetStateIfRunningMany(ctx, params) } func (p *StandardPilot) PilotInit(archetype *baseservice.Archetype) { diff --git a/rivershared/testsignal/test_signal.go b/rivershared/testsignal/test_signal.go index b8a20b03..845502a1 100644 --- a/rivershared/testsignal/test_signal.go +++ b/rivershared/testsignal/test_signal.go @@ -1,7 +1,6 @@ package testsignal import ( - "fmt" "time" "github.com/riverqueue/river/rivershared/riversharedtest" @@ -38,6 +37,20 @@ func (s *TestSignal[T]) Init(tb testutil.TestingTB) { s.tb = tb } +// RequireEmpty requires that the test signal be empty (i.e. have not received +// any values). +func (s *TestSignal[T]) RequireEmpty() { + if s.internalChan == nil { + panic("test only signal is not initialized; called outside of tests?") + } + + select { + case val := <-s.internalChan: + s.tb.Errorf("test signal should be empty, but wasn't\ngot value: %v\n", val) + default: + } +} + // Signal signals the test signal. In production where the signal hasn't been // initialized, this no ops harmlessly. In tests, the value is written to an // internal asynchronous channel which can be waited with WaitOrTimeout. @@ -52,7 +65,6 @@ func (s *TestSignal[T]) Signal(val T) { case s.internalChan <- val: default: s.tb.Errorf("test only signal channel is full") - s.tb.FailNow() } } @@ -77,9 +89,12 @@ func (s *TestSignal[T]) WaitOrTimeout() T { timeout := riversharedtest.WaitTimeout() select { - case value := <-s.internalChan: - return value + case val := <-s.internalChan: + return val case <-time.After(timeout): - panic(fmt.Sprintf("timed out waiting on test signal after %s", timeout)) + s.tb.Errorf("timed out waiting on test signal after %s", timeout) } + + var val T + return val } diff --git a/rivershared/testsignal/test_signal_test.go b/rivershared/testsignal/test_signal_test.go index 2f1bd656..6e8e22d1 100644 --- a/rivershared/testsignal/test_signal_test.go +++ b/rivershared/testsignal/test_signal_test.go @@ -49,6 +49,27 @@ func TestTestSignal(t *testing.T) { } }) + t.Run("RequireEmpty", func(t *testing.T) { + t.Parallel() + + signal := TestSignal[struct{}]{} + + require.PanicsWithValue(t, "test only signal is not initialized; called outside of tests?", func() { + signal.RequireEmpty() + }) + + mockT := testutil.NewMockT(t) + signal.Init(mockT) + + signal.RequireEmpty() // succeeds + + signal.Signal(struct{}{}) + + signal.RequireEmpty() + require.True(t, mockT.Failed) + require.Equal(t, "test signal should be empty, but wasn't\ngot value: {}\n\n", mockT.LogOutput()) + }) + t.Run("WaitC", func(t *testing.T) { t.Parallel() @@ -74,6 +95,11 @@ func TestTestSignal(t *testing.T) { t.Parallel() signal := TestSignal[struct{}]{} + + require.PanicsWithValue(t, "test only signal is not initialized; called outside of tests?", func() { + signal.WaitOrTimeout() + }) + signal.Init(t) signal.Signal(struct{}{}) diff --git a/rivershared/util/sliceutil/slice_util.go b/rivershared/util/sliceutil/slice_util.go index a0b7262a..945de5af 100644 --- a/rivershared/util/sliceutil/slice_util.go +++ b/rivershared/util/sliceutil/slice_util.go @@ -54,3 +54,21 @@ func Map[T any, R any](collection []T, mapFunc func(T) R) []R { return result } + +// Uniq returns a duplicate-free version of an array, in which only the first occurrence of each element is kept. +// The order of result values is determined by the order they occur in the array. +func Uniq[T comparable](collection []T) []T { + result := make([]T, 0, len(collection)) + seen := make(map[T]struct{}, len(collection)) + + for _, item := range collection { + if _, ok := seen[item]; ok { + continue + } + + seen[item] = struct{}{} + result = append(result, item) + } + + return result +} diff --git a/rivershared/util/sliceutil/slice_util_test.go b/rivershared/util/sliceutil/slice_util_test.go index 9026df7e..bdac4bb1 100644 --- a/rivershared/util/sliceutil/slice_util_test.go +++ b/rivershared/util/sliceutil/slice_util_test.go @@ -89,3 +89,12 @@ func TestMap(t *testing.T) { require.Equal(t, []string{"Hello", "Hello", "Hello", "Hello"}, result1) require.Equal(t, []string{"1", "2", "3", "4"}, result2) } + +func TestUniq(t *testing.T) { + t.Parallel() + + result1 := Uniq([]int{1, 2, 2, 1}) + + require.Len(t, result1, 2) + require.Equal(t, []int{1, 2}, result1) +}