Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 50 additions & 50 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ import (
"github.com/riverqueue/river/internal/notifier"
"github.com/riverqueue/river/internal/notifylimiter"
"github.com/riverqueue/river/internal/rivercommon"
"github.com/riverqueue/river/internal/util/dbutil"
"github.com/riverqueue/river/internal/workunit"
"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/rivershared/baseservice"
"github.com/riverqueue/river/rivershared/riverpilot"
"github.com/riverqueue/river/rivershared/startstop"
"github.com/riverqueue/river/rivershared/testsignal"
"github.com/riverqueue/river/rivershared/util/dbutil"
"github.com/riverqueue/river/rivershared/util/maputil"
"github.com/riverqueue/river/rivershared/util/sliceutil"
"github.com/riverqueue/river/rivershared/util/testutil"
Expand Down Expand Up @@ -796,9 +796,11 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
if client.pilot == nil {
client.pilot = &riverpilot.StandardPilot{}
}
client.pilot.PilotInit(archetype, &riverpilot.PilotInitParams{
WorkerMetadata: workerMetadata,
})
client.pilot.PilotInit(archetype, (&riverpilot.PilotInitParams{
Insert: client.insertMany,
NotifyNonTxJobInsert: client.notifyProducerWithoutListenerJobFetch,
WorkerMetadata: workerMetadata,
}).Validate())
pluginPilot, _ := client.pilot.(pilotPlugin)

if withBaseService, ok := config.RetryPolicy.(baseservice.WithBaseService); ok {
Expand Down Expand Up @@ -898,12 +900,9 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
{
periodicJobEnqueuer, err := maintenance.NewPeriodicJobEnqueuer(archetype, &maintenance.PeriodicJobEnqueuerConfig{
AdvisoryLockPrefix: config.AdvisoryLockPrefix,
Insert: func(ctx context.Context, execTx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error {
_, err := client.insertMany(ctx, execTx, insertParams)
return err
},
Pilot: client.pilot,
Schema: config.Schema,
Insert: client.insertMany,
Pilot: client.pilot,
Schema: config.Schema,
}, driver.GetExecutor())
if err != nil {
return nil, err
Expand Down Expand Up @@ -1632,16 +1631,16 @@ func (c *Client[TTx]) Insert(ctx context.Context, args JobArgs, opts *InsertOpts
return nil, errNoDriverDBPool
}

res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) (*insertManySharedResult, error) {
res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) ([]*rivertype.JobInsertResult, error) {
return c.validateParamsAndInsertMany(ctx, execTx, []InsertManyParams{{Args: args, InsertOpts: opts}})
})
if err != nil {
return nil, err
}

c.notifyProducerWithoutListenerJobFetch(res.QueuesDeduped)
c.notifyProducerWithoutListenerJobFetch(ctx, res)

return res.InsertResults[0], nil
return res[0], nil
}

// InsertTx inserts a new job with the provided args on the given transaction.
Expand All @@ -1666,7 +1665,7 @@ func (c *Client[TTx]) InsertTx(ctx context.Context, tx TTx, args JobArgs, opts *
if err != nil {
return nil, err
}
return res.InsertResults[0], nil
return res[0], nil
}

// InsertManyParams encapsulates a single job combined with insert options for
Expand Down Expand Up @@ -1698,16 +1697,16 @@ func (c *Client[TTx]) InsertMany(ctx context.Context, params []InsertManyParams)
return nil, errNoDriverDBPool
}

res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) (*insertManySharedResult, error) {
res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) ([]*rivertype.JobInsertResult, error) {
return c.validateParamsAndInsertMany(ctx, execTx, params)
})
if err != nil {
return nil, err
}

c.notifyProducerWithoutListenerJobFetch(res.QueuesDeduped)
c.notifyProducerWithoutListenerJobFetch(ctx, res)

return res.InsertResults, nil
return res, nil
}

// InsertManyTx inserts many jobs at once. Each job is inserted as an
Expand All @@ -1733,14 +1732,14 @@ func (c *Client[TTx]) InsertManyTx(ctx context.Context, tx TTx, params []InsertM
if err != nil {
return nil, err
}
return res.InsertResults, nil
return res, nil
}

// validateParamsAndInsertMany is a helper method that wraps the insertMany
// method to provide param validation and conversion prior to calling the actual
// insertMany method. This allows insertMany to be reused by the
// PeriodicJobEnqueuer which cannot reference top-level river package types.
func (c *Client[TTx]) validateParamsAndInsertMany(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) (*insertManySharedResult, error) {
func (c *Client[TTx]) validateParamsAndInsertMany(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) ([]*rivertype.JobInsertResult, error) {
insertParams, err := c.insertManyParams(params)
if err != nil {
return nil, err
Expand All @@ -1751,7 +1750,7 @@ func (c *Client[TTx]) validateParamsAndInsertMany(ctx context.Context, execTx ri

// insertMany is a shared code path for InsertMany and InsertManyTx, also used
// by the PeriodicJobEnqueuer.
func (c *Client[TTx]) insertMany(ctx context.Context, execTx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) (*insertManySharedResult, error) {
func (c *Client[TTx]) insertMany(ctx context.Context, execTx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) {
return c.insertManyShared(ctx, execTx, insertParams, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) {
results, err := c.pilot.JobInsertMany(ctx, execTx, &riverdriver.JobInsertFastManyParams{
Jobs: insertParams,
Expand All @@ -1769,11 +1768,6 @@ func (c *Client[TTx]) insertMany(ctx context.Context, execTx riverdriver.Executo
})
}

type insertManySharedResult struct {
InsertResults []*rivertype.JobInsertResult
QueuesDeduped []string
}

// The shared code path for all Insert and InsertMany methods. It takes a
// function that executes the actual insert operation and allows for different
// implementations of the insert query to be passed in, each mapping their
Expand All @@ -1783,9 +1777,7 @@ func (c *Client[TTx]) insertManyShared(
tx riverdriver.ExecutorTx,
insertParams []*rivertype.JobInsertParams,
execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error),
) (*insertManySharedResult, error) {
var queuesDeduped []string

) ([]*rivertype.JobInsertResult, error) {
doInner := func(ctx context.Context) ([]*rivertype.JobInsertResult, error) {
for _, params := range insertParams {
for _, hook := range append(
Expand Down Expand Up @@ -1814,9 +1806,7 @@ func (c *Client[TTx]) insertManyShared(
}
}

queuesDeduped = sliceutil.Uniq(queues)

if err = c.maybeNotifyInsertForQueues(ctx, tx, queuesDeduped); err != nil {
if err = c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil {
return nil, err
}

Expand All @@ -1836,15 +1826,7 @@ func (c *Client[TTx]) insertManyShared(
}
}

insertResults, err := doInner(ctx)
if err != nil {
return nil, err
}

return &insertManySharedResult{
InsertResults: insertResults,
QueuesDeduped: queuesDeduped,
}, nil
return doInner(ctx)
}

// Validates input parameters for a batch insert operation and generates a set
Expand Down Expand Up @@ -1878,13 +1860,30 @@ func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*rivertype.
// Should only ever be invoked *outside* a transaction. If invoked within a
// transaction, the producer wouldn't yet be able to access the new jobs that
// triggered the notification because they're not committed yet.
func (c *Client[TTx]) notifyProducerWithoutListenerJobFetch(queuesDeduped []string) {
func (c *Client[TTx]) notifyProducerWithoutListenerJobFetch(_ context.Context, res []*rivertype.JobInsertResult) {
if c.driver.SupportsListener() || len(c.producersByQueueName) < 1 {
return
}

for _, queue := range queuesDeduped {
if producer, ok := c.producersByQueueName[queue]; ok {
// Special case for when we were handling exactly one job, which is a very
// common case. Acts as a minor optimization by avoiding the map allocation.
if len(res) == 1 {
if producer, ok := c.producersByQueueName[res[0].Job.Queue]; ok {
producer.TriggerJobFetch()
}

return
}

queuesTriggered := make(map[string]struct{})

for _, insertRes := range res {
if _, ok := queuesTriggered[insertRes.Job.Queue]; ok {
continue
}
queuesTriggered[insertRes.Job.Queue] = struct{}{}

if producer, ok := c.producersByQueueName[insertRes.Job.Queue]; ok {
producer.TriggerJobFetch()
}
}
Expand Down Expand Up @@ -1914,16 +1913,16 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar
}

// Wrap in a transaction in case we need to notify about inserts.
res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) (*insertManySharedResult, error) {
res, err := dbutil.WithTxV(ctx, c.driver.GetExecutor(), func(ctx context.Context, execTx riverdriver.ExecutorTx) ([]*rivertype.JobInsertResult, error) {
return c.insertManyFast(ctx, execTx, params)
})
if err != nil {
return 0, err
}

c.notifyProducerWithoutListenerJobFetch(res.QueuesDeduped)
c.notifyProducerWithoutListenerJobFetch(ctx, res)

return len(res.InsertResults), nil
return len(res), nil
}

// InsertManyTx inserts many jobs at once using Postgres' `COPY FROM` mechanism,
Expand Down Expand Up @@ -1954,10 +1953,10 @@ func (c *Client[TTx]) InsertManyFastTx(ctx context.Context, tx TTx, params []Ins
if err != nil {
return 0, err
}
return len(res.InsertResults), nil
return len(res), nil
}

func (c *Client[TTx]) insertManyFast(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) (*insertManySharedResult, error) {
func (c *Client[TTx]) insertManyFast(ctx context.Context, execTx riverdriver.ExecutorTx, params []InsertManyParams) ([]*rivertype.JobInsertResult, error) {
insertParams, err := c.insertManyParams(params)
if err != nil {
return nil, err
Expand All @@ -1978,12 +1977,13 @@ func (c *Client[TTx]) insertManyFast(ctx context.Context, execTx riverdriver.Exe
// Notify the given queues that new jobs are available. The queues list will be
// deduplicated and each will be checked to see if it is due for an insert
// notification from this client.
func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdriver.ExecutorTx, queuesDeduped []string) error {
if len(queuesDeduped) < 1 {
func (c *Client[TTx]) maybeNotifyInsertForQueues(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error {
if len(queues) < 1 {
return nil
}

var (
queuesDeduped = sliceutil.Uniq(queues)
payloads = make([]string, 0, len(queuesDeduped))
queuesTriggered = make([]string, 0, len(queuesDeduped))
)
Expand Down
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ import (
"github.com/riverqueue/river/internal/rivercommon"
"github.com/riverqueue/river/internal/riverinternaltest"
"github.com/riverqueue/river/internal/riverinternaltest/retrypolicytest"
"github.com/riverqueue/river/internal/util/dbutil"
"github.com/riverqueue/river/riverdbtest"
"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/riverdriver/riverpgxv5"
"github.com/riverqueue/river/rivershared/baseservice"
"github.com/riverqueue/river/rivershared/riversharedtest"
"github.com/riverqueue/river/rivershared/startstoptest"
"github.com/riverqueue/river/rivershared/testfactory"
"github.com/riverqueue/river/rivershared/util/dbutil"
"github.com/riverqueue/river/rivershared/util/ptrutil"
"github.com/riverqueue/river/rivershared/util/randutil"
"github.com/riverqueue/river/rivershared/util/serviceutil"
Expand Down
2 changes: 1 addition & 1 deletion internal/leadership/elector.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"time"

"github.com/riverqueue/river/internal/notifier"
"github.com/riverqueue/river/internal/util/dbutil"
"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/rivershared/baseservice"
"github.com/riverqueue/river/rivershared/startstop"
"github.com/riverqueue/river/rivershared/testsignal"
"github.com/riverqueue/river/rivershared/util/dbutil"
"github.com/riverqueue/river/rivershared/util/randutil"
"github.com/riverqueue/river/rivershared/util/serviceutil"
"github.com/riverqueue/river/rivershared/util/testutil"
Expand Down
8 changes: 2 additions & 6 deletions internal/maintenance/job_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ import (
"github.com/riverqueue/river/rivershared/testsignal"
"github.com/riverqueue/river/rivershared/util/randutil"
"github.com/riverqueue/river/rivershared/util/serviceutil"
"github.com/riverqueue/river/rivershared/util/sliceutil"
"github.com/riverqueue/river/rivershared/util/testutil"
"github.com/riverqueue/river/rivershared/util/timeutil"
"github.com/riverqueue/river/rivertype"
)

const (
Expand All @@ -36,11 +34,9 @@ func (ts *JobSchedulerTestSignals) Init(tb testutil.TestingTB) {
ts.ScheduledBatch.Init(tb)
}

type InsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error

// NotifyInsert is a function to call to emit notifications for queues where
// jobs were scheduled.
type NotifyInsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, queuesDeduped []string) error
type NotifyInsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error

type JobSchedulerConfig struct {
// Interval is the amount of time between periodic checks for jobs to
Expand Down Expand Up @@ -187,7 +183,7 @@ func (s *JobScheduler) runOnce(ctx context.Context) (*schedulerRunOnceResult, er
}

if len(queues) > 0 {
if err := s.config.NotifyInsert(ctx, tx, sliceutil.Uniq(queues)); err != nil {
if err := s.config.NotifyInsert(ctx, tx, queues); err != nil {
return 0, fmt.Errorf("error notifying insert: %w", err)
}
s.TestSignals.NotifiedQueues.Signal(queues)
Expand Down
12 changes: 5 additions & 7 deletions internal/maintenance/job_scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ func TestJobScheduler(t *testing.T) {

scheduler, _ := setup(t, &testOpts{exec: exec, schema: schema})
scheduler.config.Interval = time.Minute // should only trigger once for the initial run
scheduler.config.NotifyInsert = func(ctx context.Context, tx riverdriver.ExecutorTx, queuesDeduped []string) error {
notifyCh <- queuesDeduped
scheduler.config.NotifyInsert = func(ctx context.Context, tx riverdriver.ExecutorTx, queues []string) error {
notifyCh <- queues
return nil
}
now := time.Now().UTC()
Expand Down Expand Up @@ -362,10 +362,8 @@ func TestJobScheduler(t *testing.T) {
require.NoError(t, scheduler.Start(ctx))
scheduler.TestSignals.ScheduledBatch.WaitOrTimeout()

expectedQueues := []string{"queue1", "queue2", "queue3", "queue4"}

notifiedQueuesDeduped := riversharedtest.WaitOrTimeout(t, notifyCh)
sort.Strings(notifiedQueuesDeduped)
require.Equal(t, expectedQueues, notifiedQueuesDeduped)
notifiedQueues := riversharedtest.WaitOrTimeout(t, notifyCh)
sort.Strings(notifiedQueues)
require.Equal(t, []string{"queue1", "queue2", "queue2", "queue3", "queue4"}, notifiedQueues)
})
}
4 changes: 3 additions & 1 deletion internal/maintenance/periodic_job_enqueuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ func (j *PeriodicJob) validate() error {
return nil
}

type InsertFunc func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error)

type PeriodicJobEnqueuerConfig struct {
AdvisoryLockPrefix int32

Expand Down Expand Up @@ -450,7 +452,7 @@ func (s *PeriodicJobEnqueuer) insertBatch(ctx context.Context, insertParamsMany
defer tx.Rollback(ctx)

if len(insertParamsMany) > 0 {
if err := s.Config.Insert(ctx, tx, insertParamsMany); err != nil {
if _, err := s.Config.Insert(ctx, tx, insertParamsMany); err != nil {
s.Logger.ErrorContext(ctx, s.Name+": Error inserting periodic jobs",
"error", err.Error(), "num_jobs", len(insertParamsMany))
}
Expand Down
6 changes: 3 additions & 3 deletions internal/maintenance/periodic_job_enqueuer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ func TestPeriodicJobEnqueuer(t *testing.T) {

// A simplified version of `Client.insertMany` that only inserts jobs directly
// via the driver instead of using the pilot.
makeInsertFunc := func(schema string) func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error {
return func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) error {
makeInsertFunc := func(schema string) func(ctx context.Context, execTx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) {
return func(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*rivertype.JobInsertParams) ([]*rivertype.JobInsertResult, error) {
_, err := tx.JobInsertFastMany(ctx, &riverdriver.JobInsertFastManyParams{
Jobs: sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams {
return (*riverdriver.JobInsertFastParams)(params)
}),
Schema: schema,
})
return err
return nil, err
}
}

Expand Down
1 change: 1 addition & 0 deletions riverdriver/river_driver_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ type JobGetStuckParams struct {
}

type JobInsertFastParams struct {
ID *int64
// Args contains the raw underlying job arguments struct. It has already been
// encoded into EncodedArgs, but the original is kept here for to leverage its
// struct tags and interfaces, such as for use in unique key generation.
Expand Down
Loading
Loading