Skip to content
54 changes: 13 additions & 41 deletions src/commands/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,7 @@ var command_config_main = Command{
metadata: nil,
handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) {
// Fetch config for this guild
db, err := database.Connect()
if err != nil {
log.Printf("Failed to connect to database: %v", err)
return
}
db := database.Connect()
c := db.GetConfig(i.GuildID)

j, err := json.MarshalIndent(c, "", " ")
Expand Down Expand Up @@ -83,18 +79,14 @@ var command_config_channel = Command{
},
handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) {
// Fetch config for this guild
db, err := database.Connect()
if err != nil {
log.Printf("Failed to connect to database: %v", err)
return
}
db := database.Connect()
c := db.GetConfig(i.GuildID)

// Write changes to config and save it
new_value := i.ApplicationCommandData().Options[option].ChannelValue(discord).ID
if c.Channel != new_value {
c.Channel = new_value
err = db.SaveConfig(i.GuildID, c)
err := db.SaveConfig(i.GuildID, c)
if err != nil {
log.Printf("Failed to save config: %v", err)
return
Expand Down Expand Up @@ -124,18 +116,14 @@ var command_config_threshold = Command{
},
handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) {
// Fetch config for this guild
db, err := database.Connect()
if err != nil {
log.Printf("Failed to connect to database: %v", err)
return
}
db := database.Connect()
c := db.GetConfig(i.GuildID)

// Write changes to config and save it
new_value := int(i.ApplicationCommandData().Options[option].IntValue())
if c.Threshold != new_value {
c.Threshold = new_value
err = db.SaveConfig(i.GuildID, c)
err := db.SaveConfig(i.GuildID, c)
if err != nil {
log.Printf("Failed to save config: %v", err)
return
Expand Down Expand Up @@ -163,18 +151,14 @@ var command_config_nsfw = Command{
},
handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) {
// Fetch config for this guild
db, err := database.Connect()
if err != nil {
log.Printf("Failed to connect to database: %v", err)
return
}
db := database.Connect()
c := db.GetConfig(i.GuildID)

// Write changes to config and save it
new_value := i.ApplicationCommandData().Options[option].BoolValue()
if c.NSFW != new_value {
c.NSFW = new_value
err = db.SaveConfig(i.GuildID, c)
err := db.SaveConfig(i.GuildID, c)
if err != nil {
log.Printf("Failed to save config: %v", err)
return
Expand Down Expand Up @@ -208,18 +192,14 @@ var command_config_selfpin = Command{
},
handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) {
// Fetch config for this guild
db, err := database.Connect()
if err != nil {
log.Printf("Failed to connect to database: %v", err)
return
}
db := database.Connect()
c := db.GetConfig(i.GuildID)

// Write changes to config and save it
new_value := i.ApplicationCommandData().Options[option].BoolValue()
if c.Selfpin != new_value {
c.Selfpin = new_value
err = db.SaveConfig(i.GuildID, c)
err := db.SaveConfig(i.GuildID, c)
if err != nil {
log.Printf("Failed to save config: %v", err)
return
Expand Down Expand Up @@ -256,18 +236,14 @@ var command_config_replydepth = Command{
},
handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) {
// Fetch config for this guild
db, err := database.Connect()
if err != nil {
log.Printf("Failed to connect to database: %v", err)
return
}
db := database.Connect()
c := db.GetConfig(i.GuildID)

// Write changes to config and save it
new_value := int(i.ApplicationCommandData().Options[option].IntValue())
if c.ReplyDepth != new_value {
c.ReplyDepth = new_value
err = db.SaveConfig(i.GuildID, c)
err := db.SaveConfig(i.GuildID, c)
if err != nil {
log.Printf("Failed to save config: %v", err)
return
Expand Down Expand Up @@ -295,11 +271,7 @@ var command_config_emoji = Command{
},
handler: func(discord *discordgo.Session, option int, i *discordgo.InteractionCreate) {
// Fetch config for this guild
db, err := database.Connect()
if err != nil {
log.Printf("Failed to connect to database: %v", err)
return
}
db := database.Connect()
c := db.GetConfig(i.GuildID)

// Write changes to config and save it
Expand All @@ -316,7 +288,7 @@ var command_config_emoji = Command{
c.Allowlist[emoji] = struct{}{}
}

err = db.SaveConfig(i.GuildID, c)
err := db.SaveConfig(i.GuildID, c)
if err != nil {
log.Printf("Failed to save config: %v", err)
return
Expand Down
10 changes: 2 additions & 8 deletions src/commands/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@ var command_stats_leaderboard = Command{
})

// Connect to database
db, err := database.Connect()
if err != nil {
log.Printf("Failed to connect to database: %v", err)
}
db := database.Connect()

lb, err := db.GetLeaderboard(i.GuildID)
if err != nil {
Expand Down Expand Up @@ -116,10 +113,7 @@ var command_stats_user = Command{

if user := i.ApplicationCommandData().Options[0].UserValue(discord); user != nil {
// Connect to database
db, err := database.Connect()
if err != nil {
log.Printf("Failed to connect to database: %v", err)
}
db := database.Connect()

// Build embed header
if member, err := discord.GuildMember(i.GuildID, user.ID); err == nil {
Expand Down
8 changes: 4 additions & 4 deletions src/database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (db *database) createConfigTable() error {
PRIMARY KEY (guild_id)
)
`)
_, err = db.Instance.ExecContext(context.Background(), query)
_, err := db.Instance.ExecContext(context.Background(), query)
if err != nil {
return fmt.Errorf("Failed to create config table: %w", err)
}
Expand Down Expand Up @@ -75,13 +75,13 @@ func (db *database) LoadConfig(guild_id string) (*Config, error) {

// Retrieve config data from database, load default if not found
var raw string
err = db.Instance.QueryRow("SELECT json FROM config WHERE guild_id = ?", guild_id).Scan(&raw)
err := db.Instance.QueryRow("SELECT json FROM config WHERE guild_id = ?", guild_id).Scan(&raw)
if err != nil {
return c, fmt.Errorf("Failed to retrieve config for guild '%s': %w", guild_id, err)
}

// Load config into new object
err := json.Unmarshal([]byte(raw), c)
err = json.Unmarshal([]byte(raw), c)
if err != nil {
return c, fmt.Errorf("Failed to unmarshal config for guild '%s': %w", guild_id, err)
}
Expand All @@ -94,7 +94,7 @@ func (db *database) SaveConfig(guild_id string, c *Config) error {
log.Printf("Saving config for guild '%s'\n", guild_id)

// Create guild pins table if it doesn't exist
err = db.createConfigTable()
err := db.createConfigTable()
if err != nil {
return fmt.Errorf("Failed to create config table: %w", err)
}
Expand Down
12 changes: 6 additions & 6 deletions src/database/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ type database struct {
var db *database

var once sync.Once
var err error

// Connects to database and returns a pointer to it
func Connect() (*database, error) {
func Connect() *database {
// Only runs once
once.Do(func(){
// Retrieve file name from the environment
Expand All @@ -39,11 +38,12 @@ func Connect() (*database, error) {
// Limit writes to one at a time (not ideal)
instance.SetMaxOpenConns(1)

// Assign to global variable
err = instance.Ping()
if err := instance.Ping(); err != nil {
log.Fatal("Failed to ping database: ", err)
}

db = &database{instance, make(map[string]*Config)}
})

// Returns current or new sqlite3 instance (after testing with ping)
return db, err
return db
}
8 changes: 4 additions & 4 deletions src/database/msgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func (db *database) createPinTable(guild_id string) error {
PRIMARY KEY (message_id, pin_channel_id, pin_id)
)
`, guild_id)
_, err = db.Instance.ExecContext(context.Background(), query, guild_id)
_, err := db.Instance.ExecContext(context.Background(), query, guild_id)
if err != nil {
return fmt.Errorf("Failed to create pins_%s table: %w", guild_id, err)
}
Expand All @@ -25,7 +25,7 @@ func (db *database) createPinTable(guild_id string) error {
// AddPin inserts a message_id -> pin_id pair into the guild_id table.
func (db *database) AddPin(guild_id string, pin_channel_id string, message_id string, pin_id string) error {
// Create guild pins table if it doesn't exist
err = db.createPinTable(guild_id)
err := db.createPinTable(guild_id)
if err != nil {
return err
}
Expand All @@ -42,15 +42,15 @@ func (db *database) AddPin(guild_id string, pin_channel_id string, message_id st
// GetPin retrieves the pin message id from the guild_id table given a guild_id and message_id.
func (db *database) GetPin(guild_id string, message_id string) (string, string, error) {
// Create guild pins table if it doesn't exist
err = db.createPinTable(guild_id)
err := db.createPinTable(guild_id)
if err != nil {
return "", "", err
}

// Retrieve pin_id associated with message_id
pin_id := ""
pin_channel_id := ""
err := db.Instance.QueryRowContext(
err = db.Instance.QueryRowContext(
context.Background(),
"SELECT pin_channel_id, pin_id FROM pins_" + guild_id + " WHERE message_id = ?", message_id,
).Scan(&pin_channel_id, &pin_id)
Expand Down
10 changes: 5 additions & 5 deletions src/database/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (db *database) createStatsTable(guild_id string) error {
emoji_id TEXT NOT NULL
)
`, guild_id)
_, err = db.Instance.ExecContext(context.Background(), query, guild_id)
_, err := db.Instance.ExecContext(context.Background(), query, guild_id)
if err != nil {
return fmt.Errorf("Failed to create stats_%s table: %w", guild_id, err)
}
Expand All @@ -34,7 +34,7 @@ func (db *database) createStatsTable(guild_id string) error {
// AddStat inserts a statistic into the guild_id's stats table.
func (db *database) AddStats(guild_id string, user_id string, emoji_id string) error {
// Create table if it doesn't exist
err = db.createStatsTable(guild_id)
err := db.createStatsTable(guild_id)
if err != nil {
return err
}
Expand All @@ -51,7 +51,7 @@ func (db *database) AddStats(guild_id string, user_id string, emoji_id string) e
// GetStats returns the total number of pins, with specific emojis used, a user has received in a guild.
func (db *database) GetStats(guild_id string, user_id string) (*UserStats, error) {
// Create guild pins table if it doesn't exist
err = db.createStatsTable(guild_id)
err := db.createStatsTable(guild_id)
if err != nil {
return nil, err
}
Expand All @@ -64,7 +64,7 @@ func (db *database) GetStats(guild_id string, user_id string) (*UserStats, error
SELECT COUNT(*)
FROM stats_%s
WHERE user_id = ?`, guild_id)
err := db.Instance.QueryRowContext(context.Background(), query, user_id).Scan(&count)
err = db.Instance.QueryRowContext(context.Background(), query, user_id).Scan(&count)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -98,7 +98,7 @@ func (db *database) GetStats(guild_id string, user_id string) (*UserStats, error
// GetLeaderboard returns the top ten users in a guild with the most total number of pins, with specific emojis used.
func (db *database) GetLeaderboard(guild_id string) ([]*UserStats, error) {
// Create guild pins table if it doesn't exist
err = db.createStatsTable(guild_id)
err := db.createStatsTable(guild_id)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion src/database/webhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func (db *database) createWebhookTable() error {
PRIMARY KEY (guild_id)
)
`
_, err = db.Instance.ExecContext(context.Background(), query)
_, err := db.Instance.ExecContext(context.Background(), query)
if err != nil {
return fmt.Errorf("Failed to create webhooks table: %w", err)
}
Expand Down
3 changes: 0 additions & 3 deletions src/events/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,5 @@ func RegisterAll(discord *discordgo.Session) {
discord.AddHandler(onReaction)
discord.AddHandler(onReactionRemove)
discord.AddHandler(onMessageDelete)
discord.AddHandler(onReady)
discord.AddHandler(onChannelUpdate)
discord.AddHandler(onChannelDelete)
discord.AddHandler(onPin)
}
23 changes: 17 additions & 6 deletions src/events/on_pin.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
package events

import (
"github.com/bwmarrin/discordgo"
"log"
"sync"

"github.com/bwmarrin/discordgo"
"github.com/jadc/redpin/misc"
)

// Hashmap of channel id to pin count
// Used to prevent attempting to pin when a message is unpinned
var counts = make(map[string]int)
var countsMu sync.Mutex

// hasPinCountIncreased updates the cached pin count for a channel
// and returns whether the new count is greater than the previous one.
func hasPinCountIncreased(channelID string, count int) bool {
countsMu.Lock()
defer countsMu.Unlock()

prev, ok := counts[channelID]
counts[channelID] = count

// If not cached (!ok) then pin anyway, might be a pin decrease
return !ok || count > prev
}

func onPin(discord *discordgo.Session, event *discordgo.ChannelPinsUpdate) {
// Get pinned messages in channel
Expand All @@ -18,11 +33,7 @@ func onPin(discord *discordgo.Session, event *discordgo.ChannelPinsUpdate) {
return
}

// Abort if the pin event was removal
if _, ok := counts[event.ChannelID]; !ok {
counts[event.ChannelID] = len(pins)
}
if len(pins) < counts[event.ChannelID] {
if !hasPinCountIncreased(event.ChannelID, len(pins)) {
return
}

Expand Down
Loading
Loading