diff --git a/src/commands/config.go b/src/commands/config.go index 2ddf81c..bff79e3 100644 --- a/src/commands/config.go +++ b/src/commands/config.go @@ -39,11 +39,7 @@ var command_config_main = Command{ metadata: nil, handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) { // Fetch config for this guild - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - return - } + db := database.Connect() c := db.GetConfig(i.GuildID) j, err := json.MarshalIndent(c, "", " ") @@ -83,18 +79,14 @@ var command_config_channel = Command{ }, handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) { // Fetch config for this guild - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - return - } + db := database.Connect() c := db.GetConfig(i.GuildID) // Write changes to config and save it new_value := i.ApplicationCommandData().Options[option].ChannelValue(discord).ID if c.Channel != new_value { c.Channel = new_value - err = db.SaveConfig(i.GuildID, c) + err := db.SaveConfig(i.GuildID, c) if err != nil { log.Printf("Failed to save config: %v", err) return @@ -124,18 +116,14 @@ var command_config_threshold = Command{ }, handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) { // Fetch config for this guild - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - return - } + db := database.Connect() c := db.GetConfig(i.GuildID) // Write changes to config and save it new_value := int(i.ApplicationCommandData().Options[option].IntValue()) if c.Threshold != new_value { c.Threshold = new_value - err = db.SaveConfig(i.GuildID, c) + err := db.SaveConfig(i.GuildID, c) if err != nil { log.Printf("Failed to save config: %v", err) return @@ -163,18 +151,14 @@ var command_config_nsfw = Command{ }, handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) { // Fetch config for this guild - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - return - } + db := database.Connect() c := db.GetConfig(i.GuildID) // Write changes to config and save it new_value := i.ApplicationCommandData().Options[option].BoolValue() if c.NSFW != new_value { c.NSFW = new_value - err = db.SaveConfig(i.GuildID, c) + err := db.SaveConfig(i.GuildID, c) if err != nil { log.Printf("Failed to save config: %v", err) return @@ -208,18 +192,14 @@ var command_config_selfpin = Command{ }, handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) { // Fetch config for this guild - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - return - } + db := database.Connect() c := db.GetConfig(i.GuildID) // Write changes to config and save it new_value := i.ApplicationCommandData().Options[option].BoolValue() if c.Selfpin != new_value { c.Selfpin = new_value - err = db.SaveConfig(i.GuildID, c) + err := db.SaveConfig(i.GuildID, c) if err != nil { log.Printf("Failed to save config: %v", err) return @@ -256,18 +236,14 @@ var command_config_replydepth = Command{ }, handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) { // Fetch config for this guild - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - return - } + db := database.Connect() c := db.GetConfig(i.GuildID) // Write changes to config and save it new_value := int(i.ApplicationCommandData().Options[option].IntValue()) if c.ReplyDepth != new_value { c.ReplyDepth = new_value - err = db.SaveConfig(i.GuildID, c) + err := db.SaveConfig(i.GuildID, c) if err != nil { log.Printf("Failed to save config: %v", err) return @@ -295,11 +271,7 @@ var command_config_emoji = Command{ }, handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) { // Fetch config for this guild - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - return - } + db := database.Connect() c := db.GetConfig(i.GuildID) // Write changes to config and save it @@ -316,7 +288,7 @@ var command_config_emoji = Command{ c.Allowlist[emoji] = struct{}{} } - err = db.SaveConfig(i.GuildID, c) + err := db.SaveConfig(i.GuildID, c) if err != nil { log.Printf("Failed to save config: %v", err) return diff --git a/src/commands/stats.go b/src/commands/stats.go index ed35662..18350ce 100644 --- a/src/commands/stats.go +++ b/src/commands/stats.go @@ -45,10 +45,7 @@ var command_stats_leaderboard = Command{ }) // Connect to database - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - } + db := database.Connect() lb, err := db.GetLeaderboard(i.GuildID) if err != nil { @@ -116,10 +113,7 @@ var command_stats_user = Command{ if user := i.ApplicationCommandData().Options[0].UserValue(discord); user != nil { // Connect to database - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - } + db := database.Connect() // Build embed header if member, err := discord.GuildMember(i.GuildID, user.ID); err == nil { diff --git a/src/database/config.go b/src/database/config.go index 7e6d3f7..44edac9 100644 --- a/src/database/config.go +++ b/src/database/config.go @@ -35,7 +35,7 @@ func (db *database) createConfigTable() error { PRIMARY KEY (guild_id) ) `) - _, err = db.Instance.ExecContext(context.Background(), query) + _, err := db.Instance.ExecContext(context.Background(), query) if err != nil { return fmt.Errorf("Failed to create config table: %w", err) } @@ -75,13 +75,13 @@ func (db *database) LoadConfig(guild_id string) (*Config, error) { // Retrieve config data from database, load default if not found var raw string - err = db.Instance.QueryRow("SELECT json FROM config WHERE guild_id = ?", guild_id).Scan(&raw) + err := db.Instance.QueryRow("SELECT json FROM config WHERE guild_id = ?", guild_id).Scan(&raw) if err != nil { return c, fmt.Errorf("Failed to retrieve config for guild '%s': %w", guild_id, err) } // Load config into new object - err := json.Unmarshal([]byte(raw), c) + err = json.Unmarshal([]byte(raw), c) if err != nil { return c, fmt.Errorf("Failed to unmarshal config for guild '%s': %w", guild_id, err) } @@ -94,7 +94,7 @@ func (db *database) SaveConfig(guild_id string, c *Config) error { log.Printf("Saving config for guild '%s'\n", guild_id) // Create guild pins table if it doesn't exist - err = db.createConfigTable() + err := db.createConfigTable() if err != nil { return fmt.Errorf("Failed to create config table: %w", err) } diff --git a/src/database/main.go b/src/database/main.go index 48168ec..ef05845 100644 --- a/src/database/main.go +++ b/src/database/main.go @@ -17,10 +17,9 @@ type database struct { var db *database var once sync.Once -var err error // Connects to database and returns a pointer to it -func Connect() (*database, error) { +func Connect() *database { // Only runs once once.Do(func(){ // Retrieve file name from the environment @@ -39,11 +38,12 @@ func Connect() (*database, error) { // Limit writes to one at a time (not ideal) instance.SetMaxOpenConns(1) - // Assign to global variable - err = instance.Ping() + if err := instance.Ping(); err != nil { + log.Fatal("Failed to ping database: ", err) + } + db = &database{instance, make(map[string]*Config)} }) - // Returns current or new sqlite3 instance (after testing with ping) - return db, err + return db } diff --git a/src/database/msgs.go b/src/database/msgs.go index 85d19ed..4293bdb 100644 --- a/src/database/msgs.go +++ b/src/database/msgs.go @@ -15,7 +15,7 @@ func (db *database) createPinTable(guild_id string) error { PRIMARY KEY (message_id, pin_channel_id, pin_id) ) `, guild_id) - _, err = db.Instance.ExecContext(context.Background(), query, guild_id) + _, err := db.Instance.ExecContext(context.Background(), query, guild_id) if err != nil { return fmt.Errorf("Failed to create pins_%s table: %w", guild_id, err) } @@ -25,7 +25,7 @@ func (db *database) createPinTable(guild_id string) error { // AddPin inserts a message_id -> pin_id pair into the guild_id table. func (db *database) AddPin(guild_id string, pin_channel_id string, message_id string, pin_id string) error { // Create guild pins table if it doesn't exist - err = db.createPinTable(guild_id) + err := db.createPinTable(guild_id) if err != nil { return err } @@ -42,7 +42,7 @@ func (db *database) AddPin(guild_id string, pin_channel_id string, message_id st // GetPin retrieves the pin message id from the guild_id table given a guild_id and message_id. func (db *database) GetPin(guild_id string, message_id string) (string, string, error) { // Create guild pins table if it doesn't exist - err = db.createPinTable(guild_id) + err := db.createPinTable(guild_id) if err != nil { return "", "", err } @@ -50,7 +50,7 @@ func (db *database) GetPin(guild_id string, message_id string) (string, string, // Retrieve pin_id associated with message_id pin_id := "" pin_channel_id := "" - err := db.Instance.QueryRowContext( + err = db.Instance.QueryRowContext( context.Background(), "SELECT pin_channel_id, pin_id FROM pins_" + guild_id + " WHERE message_id = ?", message_id, ).Scan(&pin_channel_id, &pin_id) diff --git a/src/database/stats.go b/src/database/stats.go index cc003be..bb5bcd9 100644 --- a/src/database/stats.go +++ b/src/database/stats.go @@ -24,7 +24,7 @@ func (db *database) createStatsTable(guild_id string) error { emoji_id TEXT NOT NULL ) `, guild_id) - _, err = db.Instance.ExecContext(context.Background(), query, guild_id) + _, err := db.Instance.ExecContext(context.Background(), query, guild_id) if err != nil { return fmt.Errorf("Failed to create stats_%s table: %w", guild_id, err) } @@ -34,7 +34,7 @@ func (db *database) createStatsTable(guild_id string) error { // AddStat inserts a statistic into the guild_id's stats table. func (db *database) AddStats(guild_id string, user_id string, emoji_id string) error { // Create table if it doesn't exist - err = db.createStatsTable(guild_id) + err := db.createStatsTable(guild_id) if err != nil { return err } @@ -51,7 +51,7 @@ func (db *database) AddStats(guild_id string, user_id string, emoji_id string) e // GetStats returns the total number of pins, with specific emojis used, a user has received in a guild. func (db *database) GetStats(guild_id string, user_id string) (*UserStats, error) { // Create guild pins table if it doesn't exist - err = db.createStatsTable(guild_id) + err := db.createStatsTable(guild_id) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func (db *database) GetStats(guild_id string, user_id string) (*UserStats, error SELECT COUNT(*) FROM stats_%s WHERE user_id = ?`, guild_id) - err := db.Instance.QueryRowContext(context.Background(), query, user_id).Scan(&count) + err = db.Instance.QueryRowContext(context.Background(), query, user_id).Scan(&count) if err != nil { return nil, err } @@ -98,7 +98,7 @@ func (db *database) GetStats(guild_id string, user_id string) (*UserStats, error // GetLeaderboard returns the top ten users in a guild with the most total number of pins, with specific emojis used. func (db *database) GetLeaderboard(guild_id string) ([]*UserStats, error) { // Create guild pins table if it doesn't exist - err = db.createStatsTable(guild_id) + err := db.createStatsTable(guild_id) if err != nil { return nil, err } diff --git a/src/database/webhooks.go b/src/database/webhooks.go index 2ce78ce..cd533cb 100644 --- a/src/database/webhooks.go +++ b/src/database/webhooks.go @@ -15,7 +15,7 @@ func (db *database) createWebhookTable() error { PRIMARY KEY (guild_id) ) ` - _, err = db.Instance.ExecContext(context.Background(), query) + _, err := db.Instance.ExecContext(context.Background(), query) if err != nil { return fmt.Errorf("Failed to create webhooks table: %w", err) } diff --git a/src/events/main.go b/src/events/main.go index a351d82..580dfbe 100644 --- a/src/events/main.go +++ b/src/events/main.go @@ -15,8 +15,5 @@ func RegisterAll(discord *discordgo.Session) { discord.AddHandler(onReaction) discord.AddHandler(onReactionRemove) discord.AddHandler(onMessageDelete) - discord.AddHandler(onReady) - discord.AddHandler(onChannelUpdate) - discord.AddHandler(onChannelDelete) discord.AddHandler(onPin) } diff --git a/src/events/on_pin.go b/src/events/on_pin.go index f21358c..58f9a7c 100644 --- a/src/events/on_pin.go +++ b/src/events/on_pin.go @@ -1,15 +1,30 @@ package events import ( - "github.com/bwmarrin/discordgo" "log" + "sync" + "github.com/bwmarrin/discordgo" "github.com/jadc/redpin/misc" ) // Hashmap of channel id to pin count // Used to prevent attempting to pin when a message is unpinned var counts = make(map[string]int) +var countsMu sync.Mutex + +// hasPinCountIncreased updates the cached pin count for a channel +// and returns whether the new count is greater than the previous one. +func hasPinCountIncreased(channelID string, count int) bool { + countsMu.Lock() + defer countsMu.Unlock() + + prev, ok := counts[channelID] + counts[channelID] = count + + // If not cached (!ok) then pin anyway, might be a pin decrease + return !ok || count > prev +} func onPin(discord *discordgo.Session, event *discordgo.ChannelPinsUpdate) { // Get pinned messages in channel @@ -18,11 +33,7 @@ func onPin(discord *discordgo.Session, event *discordgo.ChannelPinsUpdate) { return } - // Abort if the pin event was removal - if _, ok := counts[event.ChannelID]; !ok { - counts[event.ChannelID] = len(pins) - } - if len(pins) < counts[event.ChannelID] { + if !hasPinCountIncreased(event.ChannelID, len(pins)) { return } diff --git a/src/events/on_reaction.go b/src/events/on_reaction.go index dd4c422..a20886e 100644 --- a/src/events/on_reaction.go +++ b/src/events/on_reaction.go @@ -1,17 +1,19 @@ package events import ( - "github.com/bwmarrin/discordgo" "log" + "sync" + "github.com/bwmarrin/discordgo" "github.com/jadc/redpin/database" "github.com/jadc/redpin/misc" ) -// Map of message id to reaction emoji id +// Map of message id to reaction emoji API name // Used to prevent self-pinning (if that is disabled) // Cached to reduce API calls var selfpin = make(map[string]map[string]struct{}) +var selfpinMu sync.RWMutex func onReaction(discord *discordgo.Session, event *discordgo.MessageReactionAdd) { reaction := event.MessageReaction @@ -29,18 +31,15 @@ func onReaction(discord *discordgo.Session, event *discordgo.MessageReactionAdd) // Update selfpin map on reaction add if reaction.UserID == message.Author.ID { + selfpinMu.Lock() if selfpin[event.MessageID] == nil { selfpin[event.MessageID] = make(map[string]struct{}) } - - selfpin[event.MessageID][event.Emoji.ID] = struct{}{} + selfpin[event.MessageID][event.Emoji.APIName()] = struct{}{} + selfpinMu.Unlock() } - db, err := database.Connect() - if err != nil { - log.Printf("Failed to connect to database: %v", err) - return - } + db := database.Connect() c := db.GetConfig(event.GuildID) // Ignore reactions in pin channel @@ -54,8 +53,11 @@ func onReaction(discord *discordgo.Session, event *discordgo.MessageReactionAdd) } // Ignore reactions in NSFW channels - if !c.NSFW && isNSFW(discord, reaction.ChannelID) { - return + if !c.NSFW { + channel, err := discord.State.Channel(reaction.ChannelID) + if err != nil || channel.NSFW { + return + } } if !shouldPin(c, message) { @@ -80,27 +82,34 @@ func onReaction(discord *discordgo.Session, event *discordgo.MessageReactionAdd) // Update selfpin map on reaction remove func onReactionRemove(discord *discordgo.Session, event *discordgo.MessageReactionRemove) { - if selfpin[event.MessageID] != nil { - reaction := event.MessageReaction - message, err := discord.ChannelMessage(reaction.ChannelID, reaction.MessageID) - if err != nil { - log.Printf("Failed to fetch message '%s': %v", reaction.MessageID, err) - return - } + selfpinMu.RLock() + _, exists := selfpin[event.MessageID] + selfpinMu.RUnlock() + if !exists { + return + } - if reaction.UserID == message.Author.ID { - if selfpin[event.MessageID] != nil { - delete(selfpin[event.MessageID], event.Emoji.ID) - } + reaction := event.MessageReaction + message, err := discord.ChannelMessage(reaction.ChannelID, reaction.MessageID) + if err != nil { + log.Printf("Failed to fetch message '%s': %v", reaction.MessageID, err) + return + } + + if reaction.UserID == message.Author.ID { + selfpinMu.Lock() + if selfpin[event.MessageID] != nil { + delete(selfpin[event.MessageID], event.Emoji.APIName()) } + selfpinMu.Unlock() } } // Update selfpin map on message delete func onMessageDelete(discord *discordgo.Session, event *discordgo.MessageDelete) { - if selfpin[event.Message.ID] != nil { - delete(selfpin, event.Message.ID) - } + selfpinMu.Lock() + delete(selfpin, event.Message.ID) + selfpinMu.Unlock() } // shouldPin checks all reactions of the messsage, and determines if the message should be pinned. @@ -118,7 +127,10 @@ func shouldPin(c *database.Config, message *discordgo.Message) bool { // Remove reactions from the message author from the count if !c.Selfpin { - if _, ok := selfpin[message.ID][r.Emoji.ID]; ok { + selfpinMu.RLock() + _, isSelfpin := selfpin[message.ID][r.Emoji.APIName()] + selfpinMu.RUnlock() + if isSelfpin { count-- } } @@ -132,37 +144,3 @@ func shouldPin(c *database.Config, message *discordgo.Message) bool { return false } -// Hashset of ids of NSFW channels -// Cached to reduce API calls -var is_nsfw = make(map[string]bool) - -// isNSFW returns whether a channel is NSFW or not, reading from cache when possible -func isNSFW(discord *discordgo.Session, channel_id string) bool { - // Attempt to read from cache - if nsfw, ok := is_nsfw[channel_id]; ok { - return nsfw - } - - // Otherwise, query from Discord - channel, _ := discord.Channel(channel_id) - is_nsfw[channel_id] = channel.NSFW - return is_nsfw[channel_id] -} - -func onReady(discord *discordgo.Session, event *discordgo.Ready) { - for _, guild := range event.Guilds { - channels, _ := discord.GuildChannels(guild.ID) - for _, channel := range channels { - is_nsfw[channel.ID] = channel.NSFW - } - } - log.Printf("Cached %d channels' NSFW status", len(is_nsfw)) -} - -func onChannelUpdate(discord *discordgo.Session, event *discordgo.ChannelUpdate) { - is_nsfw[event.Channel.ID] = event.NSFW -} - -func onChannelDelete(discord *discordgo.Session, event *discordgo.ChannelDelete) { - delete(is_nsfw, event.Channel.ID) -} diff --git a/src/misc/pin.go b/src/misc/pin.go index 6b5bf10..dc69cd5 100644 --- a/src/misc/pin.go +++ b/src/misc/pin.go @@ -1,12 +1,15 @@ package misc import ( + "bytes" "database/sql" "errors" "fmt" - "log" - "net/http" - "strings" + "io" + "log" + "net/http" + "strings" + "sync" "github.com/bwmarrin/discordgo" "github.com/jadc/redpin/database" @@ -33,6 +36,27 @@ type PinRequest struct { // Hashset of messages currently being pinned // Helps prevent rapid reactions from pinning a message twice var pinning = make(map[string]struct{}) +var pinningMu sync.Mutex + +// startPinning marks a message as currently being pinned. +// Returns false if it was already being pinned. +func startPinning(messageID string) bool { + pinningMu.Lock() + defer pinningMu.Unlock() + + if _, ok := pinning[messageID]; ok { + return false + } + pinning[messageID] = struct{}{} + return true +} + +// donePinning removes a message from the currently-pinning set. +func donePinning(messageID string) { + pinningMu.Lock() + delete(pinning, messageID) + pinningMu.Unlock() +} // CreatePinRequest creates a copy of the message, and all messages it references, in its current state func CreatePinRequest(discord *discordgo.Session, guild_id string, message *discordgo.Message) (*PinRequest, error) { @@ -42,16 +66,12 @@ func CreatePinRequest(discord *discordgo.Session, guild_id string, message *disc } // Skip messages currently being pinned - if _, ok := pinning[message.ID]; ok { + if !startPinning(message.ID) { return nil, ALREADY_PINNED } - pinning[message.ID] = struct{}{} // Retrieve current config - db, err := database.Connect() - if err != nil { - return nil, fmt.Errorf("Failed to connect to database: %v", err) - } + db := database.Connect() c := db.GetConfig(guild_id) // Create pin request @@ -87,10 +107,9 @@ func CreatePinRequest(discord *discordgo.Session, guild_id string, message *disc // Execute on a PinRequest pins the message, forwarding it to the pin channel // Returns the used pin channel ID and pin message's ID if successful func (req *PinRequest) Execute(discord *discordgo.Session) (string, string, error) { - db, err := database.Connect() - if err != nil { - return "", "", fmt.Errorf("Failed to connect to database: %v", err) - } + defer donePinning(req.message.ID) + + db := database.Connect() // Query database for if message is already pinned pin_channel_id, pin_msg_id, err := db.GetPin(req.guildID, req.message.ID) @@ -165,7 +184,6 @@ func (req *PinRequest) Execute(discord *discordgo.Session) (string, string, erro if err != nil { return "", "", fmt.Errorf("Failed to add pin to database: %v", err) } - delete(pinning, req.message.ID) log.Printf("Pinned message '%s' in guild '%s'", req.message.ID, req.guildID) return pin_msg.ChannelID, pin_msg.ID, nil @@ -288,10 +306,10 @@ func splitAttachments(attachments []*discordgo.MessageAttachment, size_limit int } if a.Size > 0 && a.Size < size_limit { - // Download attachment - data, err := http.DefaultClient.Get(a.URL) + // Download attachment, falling back to proxy URL + body, err := downloadAttachment(a.URL) if err != nil { - data, err = http.DefaultClient.Get(a.ProxyURL) + body, err = downloadAttachment(a.ProxyURL) if err != nil { // Append link instead if downloading attachment fails links = append(links, a.URL) @@ -303,13 +321,14 @@ func splitAttachments(attachments []*discordgo.MessageAttachment, size_limit int if len(files) >= MAX_FILES || size + a.Size >= size_limit { file_sets = append(file_sets, files) files = make([]*discordgo.File, 0, MAX_FILES) + size = 0 } - // Create file with attachment data + // Create file with buffered attachment data file := &discordgo.File{ Name: a.Filename, ContentType: a.ContentType, - Reader: data.Body, + Reader: bytes.NewReader(body), } files = append(files, file) size += a.Size @@ -326,10 +345,25 @@ func splitAttachments(attachments []*discordgo.MessageAttachment, size_limit int return file_sets, link_sets } +// downloadAttachment fetches a URL and returns the body as a byte slice, closing the response body immediately. +func downloadAttachment(url string) ([]byte, error) { + resp, err := http.DefaultClient.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status %d for %s", resp.StatusCode, url) + } + + return io.ReadAll(resp.Body) +} + // sizeLimit returns the maximum size (in bytes) of a message that can be sent in a guild func sizeLimit(discord *discordgo.Session, guild_id string) (int, error) { - // Get guild object - guild, err := discord.Guild(guild_id) + // Get (cached) guild object + guild, err := discord.State.Guild(guild_id) if err != nil { return 0, fmt.Errorf("Failed to retrieve guild '%s': %v", guild_id, err) } diff --git a/src/misc/webhook.go b/src/misc/webhook.go index 9196422..0288d37 100644 --- a/src/misc/webhook.go +++ b/src/misc/webhook.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "log" + "sync" "github.com/bwmarrin/discordgo" "github.com/jadc/redpin/database" @@ -25,80 +26,83 @@ type WebhookPair struct { // Technically, if the bot is restarted, the LRU bit is reset and // messages may get merged, but this is a very unlikely scenario var webhooks = make(map[string]*WebhookPair) +var webhooksMu sync.Mutex -// GetWebhook returns the appropriate webhook ID for a given guild +// GetWebhook returns the appropriate webhook for a given guild func GetWebhook(discord *discordgo.Session, guild_id string) (*discordgo.Webhook, error) { - db, err := database.Connect() - if err != nil { - return nil, fmt.Errorf("Failed to connect to database: %v", err) - } + webhooksMu.Lock() + defer webhooksMu.Unlock() + + db := database.Connect() c := db.GetConfig(guild_id) - // Attempt to read from cache if pair, ok := webhooks[guild_id]; ok { - // Create new webhooks, and delete the old ones, if expired + // Recreate cached webhooks if pin channel has changed if pair.WebhookA.ChannelID != c.Channel || pair.WebhookB.ChannelID != c.Channel { discord.WebhookDelete(pair.WebhookA.ID) discord.WebhookDelete(pair.WebhookB.ID) + var err error pair, err = createWebhook(discord, guild_id, c.Channel) if err != nil { return nil, err } } - // Alternate between the two webhooks - pair.LRU = !pair.LRU - if pair.LRU { - return pair.WebhookB, nil - } else { - return pair.WebhookA, nil - } + // Otherwise, use other webhook in cached pair + return alternateWebhook(pair), nil } - // Query database for an existing webhook + // Fetch webhook pair from database if not cached webhook_a_id, webhook_b_id, err := db.GetWebhook(guild_id) - - // Only throw up error if it's an actual error (not just row not found) if err != nil && err != sql.ErrNoRows { return nil, fmt.Errorf("Failed to retrieve webhook pair for guild '%s': %v", guild_id, err) } - // Create new webhook if one does not exist already + var pair *WebhookPair + if err == sql.ErrNoRows || len(webhook_a_id) == 0 || len(webhook_b_id) == 0 { - _, err := createWebhook(discord, guild_id, c.Channel) + // If no webhook pair in database, create new one + pair, err = createWebhook(discord, guild_id, c.Channel) if err != nil { return nil, err } - return GetWebhook(discord, guild_id) - } + } else { + // If webhook pair in databse, fetch webhook object and cache it + var webhook_a *discordgo.Webhook + var webhook_b *discordgo.Webhook + webhook_a, err = discord.Webhook(webhook_a_id) + if err == nil { + webhook_b, err = discord.Webhook(webhook_b_id) + } - // Retrieve webhook objects, create a new one if it's invalid - var webhook_a *discordgo.Webhook - var webhook_b *discordgo.Webhook - webhook_a, err = discord.Webhook(webhook_a_id) - if err == nil { - webhook_b, err = discord.Webhook(webhook_b_id) - } - if err != nil { - _, err := createWebhook(discord, guild_id, c.Channel) if err != nil { - return nil, err + // If stored webhook IDs are stale/deleted, create new pair + pair, err = createWebhook(discord, guild_id, c.Channel) + if err != nil { + return nil, err + } + } else { + pair = &WebhookPair{ WebhookA: webhook_a, WebhookB: webhook_b } + webhooks[guild_id] = pair } - return GetWebhook(discord, guild_id) } - // Cache ond return existing webhook pair - webhooks[guild_id] = &WebhookPair{ WebhookA: webhook_a, WebhookB: webhook_b } - return GetWebhook(discord, guild_id) + return alternateWebhook(pair), nil +} + +// alternateWebhook flips the LRU bit and returns the next webhook in the pair. +func alternateWebhook(pair *WebhookPair) *discordgo.Webhook { + pair.LRU = !pair.LRU + if pair.LRU { + return pair.WebhookB + } + return pair.WebhookA } // createWebhook creates a webhook pair for a given guild in the given pin channel func createWebhook(discord *discordgo.Session, guild_id string, channel_id string) (*WebhookPair, error) { - db, err := database.Connect() - if err != nil { - return nil, fmt.Errorf("Failed to connect to database: %v", err) - } + db := database.Connect() // Create webhook A in given channel webhookA, err := discord.WebhookCreate(channel_id, "redpin A", "")