diff --git a/cmd/msgvault/cmd/addo365.go b/cmd/msgvault/cmd/addo365.go new file mode 100644 index 00000000..b39314a2 --- /dev/null +++ b/cmd/msgvault/cmd/addo365.go @@ -0,0 +1,109 @@ +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" + imapclient "github.com/wesm/msgvault/internal/imap" + "github.com/wesm/msgvault/internal/microsoft" + "github.com/wesm/msgvault/internal/store" +) + +var o365TenantID string + +var addO365Cmd = &cobra.Command{ + Use: "add-o365 ", + Short: "Add a Microsoft 365 account via OAuth", + Long: `Add a Microsoft 365 / Outlook.com email account using OAuth2 authentication. + +This opens a browser for Microsoft authorization, then configures IMAP access +to outlook.office365.com automatically using the XOAUTH2 SASL mechanism. + +Requires a [microsoft] section in config.toml with your Azure AD app's client_id. +See the docs for Azure AD app registration setup. + +Examples: + msgvault add-o365 user@outlook.com + msgvault add-o365 user@company.com --tenant my-tenant-id`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + email := args[0] + + if cfg.Microsoft.ClientID == "" { + return fmt.Errorf("Microsoft OAuth not configured.\n\n" + + "Add to your config.toml:\n\n" + + " [microsoft]\n" + + " client_id = \"your-azure-app-client-id\"\n\n" + + "See docs for Azure AD app registration setup.") + } + + tenantID := cfg.Microsoft.EffectiveTenantID() + if o365TenantID != "" { + tenantID = o365TenantID + } + + msMgr := microsoft.NewManager( + cfg.Microsoft.ClientID, + tenantID, + cfg.TokensDir(), + logger, + ) + + fmt.Printf("Authorizing %s with Microsoft...\n", email) + if err := msMgr.Authorize(cmd.Context(), email); err != nil { + return fmt.Errorf("authorization failed: %w", err) + } + + // Auto-configure IMAP for outlook.office365.com + imapCfg := &imapclient.Config{ + Host: "outlook.office365.com", + Port: 993, + TLS: true, + Username: email, + AuthMethod: imapclient.AuthXOAuth2, + } + + dbPath := cfg.DatabaseDSN() + s, err := store.Open(dbPath) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + defer s.Close() + + if err := s.InitSchema(); err != nil { + return fmt.Errorf("init schema: %w", err) + } + + identifier := imapCfg.Identifier() + source, err := s.GetOrCreateSource("imap", identifier) + if err != nil { + return fmt.Errorf("create source: %w", err) + } + + cfgJSON, err := imapCfg.ToJSON() + if err != nil { + return fmt.Errorf("serialize config: %w", err) + } + if err := s.UpdateSourceSyncConfig(source.ID, cfgJSON); err != nil { + return fmt.Errorf("store config: %w", err) + } + if err := s.UpdateSourceDisplayName(source.ID, email); err != nil { + return fmt.Errorf("set display name: %w", err) + } + + fmt.Printf("\nMicrosoft 365 account added successfully!\n") + fmt.Printf(" Email: %s\n", email) + fmt.Printf(" Identifier: %s\n", identifier) + fmt.Println() + fmt.Println("You can now run:") + fmt.Printf(" msgvault sync-full %s\n", identifier) + + return nil + }, +} + +func init() { + addO365Cmd.Flags().StringVar(&o365TenantID, "tenant", "", + "Azure AD tenant ID (default: \"common\" for multi-tenant)") + rootCmd.AddCommand(addO365Cmd) +} diff --git a/cmd/msgvault/cmd/remove_account.go b/cmd/msgvault/cmd/remove_account.go index d38fa4d5..9fc0e908 100644 --- a/cmd/msgvault/cmd/remove_account.go +++ b/cmd/msgvault/cmd/remove_account.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" imaplib "github.com/wesm/msgvault/internal/imap" + "github.com/wesm/msgvault/internal/microsoft" "github.com/wesm/msgvault/internal/oauth" "github.com/wesm/msgvault/internal/store" ) @@ -139,6 +140,23 @@ func runRemoveAccount(cmd *cobra.Command, args []string) error { credPath, err, ) } + // Also clean up Microsoft OAuth token if this was an XOAUTH2 source + if source.SyncConfig.Valid && source.SyncConfig.String != "" { + imapCfg, parseErr := imaplib.ConfigFromJSON(source.SyncConfig.String) + if parseErr == nil && imapCfg.EffectiveAuthMethod() == imaplib.AuthXOAuth2 { + msMgr := microsoft.NewManager( + cfg.Microsoft.ClientID, + cfg.Microsoft.EffectiveTenantID(), + cfg.TokensDir(), + logger, + ) + if err := msMgr.DeleteToken(imapCfg.Username); err != nil { + fmt.Fprintf(os.Stderr, + "Warning: could not remove Microsoft token: %v\n", err, + ) + } + } + } } // Remove analytics cache (shared across accounts, needs full rebuild) diff --git a/cmd/msgvault/cmd/sync.go b/cmd/msgvault/cmd/sync.go index fc258eb0..667b3532 100644 --- a/cmd/msgvault/cmd/sync.go +++ b/cmd/msgvault/cmd/sync.go @@ -14,6 +14,7 @@ import ( "github.com/spf13/cobra" "github.com/wesm/msgvault/internal/gmail" imaplib "github.com/wesm/msgvault/internal/imap" + "github.com/wesm/msgvault/internal/microsoft" "github.com/wesm/msgvault/internal/oauth" "github.com/wesm/msgvault/internal/store" "github.com/wesm/msgvault/internal/sync" @@ -145,8 +146,21 @@ Examples: } gmailTargets = append(gmailTargets, syncTarget{source: src, email: src.Identifier}) case "imap": - if !imaplib.HasCredentials(cfg.TokensDir(), src.Identifier) { - fmt.Printf("Skipping %s (no credentials - run 'add-imap' first)\n", src.Identifier) + hasAuth := imaplib.HasCredentials(cfg.TokensDir(), src.Identifier) + if !hasAuth && src.SyncConfig.Valid { + imapCfg, parseErr := imaplib.ConfigFromJSON(src.SyncConfig.String) + if parseErr == nil && imapCfg.EffectiveAuthMethod() == imaplib.AuthXOAuth2 { + msMgr := microsoft.NewManager( + cfg.Microsoft.ClientID, + cfg.Microsoft.EffectiveTenantID(), + cfg.TokensDir(), + logger, + ) + hasAuth = msMgr.HasToken(imapCfg.Username) + } + } + if !hasAuth { + fmt.Printf("Skipping %s (no credentials - run 'add-imap' or 'add-o365' first)\n", src.Identifier) continue } imapTargets = append(imapTargets, src) diff --git a/cmd/msgvault/cmd/syncfull.go b/cmd/msgvault/cmd/syncfull.go index ecba7d21..13351b04 100644 --- a/cmd/msgvault/cmd/syncfull.go +++ b/cmd/msgvault/cmd/syncfull.go @@ -13,6 +13,7 @@ import ( "github.com/spf13/cobra" "github.com/wesm/msgvault/internal/gmail" imaplib "github.com/wesm/msgvault/internal/imap" + "github.com/wesm/msgvault/internal/microsoft" "github.com/wesm/msgvault/internal/oauth" "github.com/wesm/msgvault/internal/store" "github.com/wesm/msgvault/internal/sync" @@ -133,8 +134,21 @@ Examples: continue } case "imap": - if !imaplib.HasCredentials(cfg.TokensDir(), src.Identifier) { - fmt.Printf("Skipping %s (no credentials - run 'add-imap' first)\n", src.Identifier) + hasAuth := imaplib.HasCredentials(cfg.TokensDir(), src.Identifier) + if !hasAuth && src.SyncConfig.Valid && src.SyncConfig.String != "" { + imapCfg, parseErr := imaplib.ConfigFromJSON(src.SyncConfig.String) + if parseErr == nil && imapCfg.EffectiveAuthMethod() == imaplib.AuthXOAuth2 { + msMgr := microsoft.NewManager( + cfg.Microsoft.ClientID, + cfg.Microsoft.EffectiveTenantID(), + cfg.TokensDir(), + logger, + ) + hasAuth = msMgr.HasToken(imapCfg.Username) + } + } + if !hasAuth { + fmt.Printf("Skipping %s (no credentials - run 'add-imap' or 'add-o365' first)\n", src.Identifier) continue } default: @@ -220,11 +234,31 @@ func buildAPIClient(ctx context.Context, src *store.Source, oauthMgr *oauth.Mana if err != nil { return nil, fmt.Errorf("parse IMAP config: %w", err) } - password, err := imaplib.LoadCredentials(cfg.TokensDir(), src.Identifier) - if err != nil { - return nil, fmt.Errorf("load IMAP credentials: %w (run 'add-imap' first)", err) + + var opts []imaplib.Option + opts = append(opts, imaplib.WithLogger(logger)) + + switch imapCfg.EffectiveAuthMethod() { + case imaplib.AuthXOAuth2: + msMgr := microsoft.NewManager( + cfg.Microsoft.ClientID, + cfg.Microsoft.EffectiveTenantID(), + cfg.TokensDir(), + logger, + ) + tokenFn, err := msMgr.TokenSource(ctx, imapCfg.Username) + if err != nil { + return nil, fmt.Errorf("load Microsoft token: %w (run 'add-o365' first)", err) + } + opts = append(opts, imaplib.WithTokenSource(tokenFn)) + return imaplib.NewClient(imapCfg, "", opts...), nil + default: + password, err := imaplib.LoadCredentials(cfg.TokensDir(), src.Identifier) + if err != nil { + return nil, fmt.Errorf("load IMAP credentials: %w (run 'add-imap' first)", err) + } + return imaplib.NewClient(imapCfg, password, opts...), nil } - return imaplib.NewClient(imapCfg, password, imaplib.WithLogger(logger)), nil default: return nil, fmt.Errorf("unsupported source type %q", src.SourceType) diff --git a/go.mod b/go.mod index 866e666e..27e93de9 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/x/ansi v0.11.6 github.com/emersion/go-imap/v2 v2.0.0-beta.8 + github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 github.com/go-chi/chi/v5 v5.2.5 github.com/gogs/chardet v0.0.0-20211120154057-b7413eaefb8f github.com/google/go-cmp v0.7.0 @@ -45,7 +46,6 @@ require ( github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.5.0 // indirect github.com/emersion/go-message v0.18.2 // indirect - github.com/emersion/go-sasl v0.0.0-20241020182733-b788ff22d5a6 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/go-viper/mapstructure/v2 v2.3.0 // indirect github.com/goccy/go-json v0.10.5 // indirect diff --git a/internal/config/config.go b/internal/config/config.go index 84de5a56..ad790931 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -70,13 +70,14 @@ type RemoteConfig struct { // Config represents the msgvault configuration. type Config struct { - Data DataConfig `toml:"data"` - OAuth OAuthConfig `toml:"oauth"` - Sync SyncConfig `toml:"sync"` - Chat ChatConfig `toml:"chat"` - Server ServerConfig `toml:"server"` - Remote RemoteConfig `toml:"remote"` - Accounts []AccountSchedule `toml:"accounts"` + Data DataConfig `toml:"data"` + OAuth OAuthConfig `toml:"oauth"` + Microsoft MicrosoftConfig `toml:"microsoft"` + Sync SyncConfig `toml:"sync"` + Chat ChatConfig `toml:"chat"` + Server ServerConfig `toml:"server"` + Remote RemoteConfig `toml:"remote"` + Accounts []AccountSchedule `toml:"accounts"` // Computed paths (not from config file) HomeDir string `toml:"-"` @@ -94,6 +95,21 @@ type OAuthConfig struct { ClientSecrets string `toml:"client_secrets"` } +// MicrosoftConfig holds Microsoft 365 / Azure AD OAuth configuration. +type MicrosoftConfig struct { + ClientID string `toml:"client_id"` + TenantID string `toml:"tenant_id"` +} + +// EffectiveTenantID returns the tenant ID, defaulting to "common" +// (multi-tenant, works for personal + org accounts). +func (c *MicrosoftConfig) EffectiveTenantID() string { + if c.TenantID == "" { + return "common" + } + return c.TenantID +} + // SyncConfig holds sync-related configuration. type SyncConfig struct { RateLimitQPS int `toml:"rate_limit_qps"` diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 6109dca3..792c0382 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1168,3 +1168,32 @@ func TestSave_AllowInsecureRoundTrip(t *testing.T) { t.Error("AllowInsecure should be true after saving with true") } } + +func TestMicrosoftConfig(t *testing.T) { + tmpDir := t.TempDir() + configContent := ` +[microsoft] +client_id = "test-client-id-123" +tenant_id = "my-tenant" +` + configPath := filepath.Join(tmpDir, "config.toml") + os.WriteFile(configPath, []byte(configContent), 0644) + + cfg, err := Load(configPath, tmpDir) + if err != nil { + t.Fatal(err) + } + if cfg.Microsoft.ClientID != "test-client-id-123" { + t.Errorf("Microsoft.ClientID = %q, want %q", cfg.Microsoft.ClientID, "test-client-id-123") + } + if cfg.Microsoft.TenantID != "my-tenant" { + t.Errorf("Microsoft.TenantID = %q, want %q", cfg.Microsoft.TenantID, "my-tenant") + } +} + +func TestMicrosoftConfig_DefaultTenant(t *testing.T) { + cfg := NewDefaultConfig() + if cfg.Microsoft.EffectiveTenantID() != "common" { + t.Errorf("EffectiveTenantID() = %q, want %q", cfg.Microsoft.EffectiveTenantID(), "common") + } +} diff --git a/internal/imap/client.go b/internal/imap/client.go index 45487eab..3d89bc4b 100644 --- a/internal/imap/client.go +++ b/internal/imap/client.go @@ -22,6 +22,12 @@ func WithLogger(logger *slog.Logger) Option { return func(c *Client) { c.logger = logger } } +// WithTokenSource sets a callback that provides OAuth2 access tokens +// for XOAUTH2 SASL authentication. Required when Config.AuthMethod is AuthXOAuth2. +func WithTokenSource(fn func(ctx context.Context) (string, error)) Option { + return func(c *Client) { c.tokenSource = fn } +} + // fetchChunkSize is the maximum number of UIDs per UID FETCH command. // Large FETCH sets cause server-side timeouts on big mailboxes; chunking // keeps each round-trip short. @@ -33,9 +39,10 @@ const listPageSize = 500 // Client implements gmail.API for IMAP servers. type Client struct { - config *Config - password string - logger *slog.Logger + config *Config + password string + tokenSource func(ctx context.Context) (string, error) // XOAUTH2 token callback + logger *slog.Logger mu sync.Mutex conn *imapclient.Client @@ -87,9 +94,27 @@ func (c *Client) connect(ctx context.Context) error { return fmt.Errorf("dial IMAP %s: %w", addr, err) } - if err := conn.Login(c.config.Username, c.password).Wait(); err != nil { - _ = conn.Close() - return fmt.Errorf("IMAP login: %w", err) + switch c.config.EffectiveAuthMethod() { + case AuthXOAuth2: + if c.tokenSource == nil { + _ = conn.Close() + return fmt.Errorf("XOAUTH2 auth requires a token source (use WithTokenSource)") + } + token, err := c.tokenSource(ctx) + if err != nil { + _ = conn.Close() + return fmt.Errorf("get XOAUTH2 token: %w", err) + } + saslClient := NewXOAuth2Client(c.config.Username, token) + if err := conn.Authenticate(saslClient); err != nil { + _ = conn.Close() + return fmt.Errorf("XOAUTH2 authenticate: %w", err) + } + default: + if err := conn.Login(c.config.Username, c.password).Wait(); err != nil { + _ = conn.Close() + return fmt.Errorf("IMAP login: %w", err) + } } c.conn = conn diff --git a/internal/imap/client_xoauth2_test.go b/internal/imap/client_xoauth2_test.go new file mode 100644 index 00000000..cfb77c62 --- /dev/null +++ b/internal/imap/client_xoauth2_test.go @@ -0,0 +1,36 @@ +package imap + +import ( + "context" + "testing" +) + +func TestNewClient_WithTokenSource(t *testing.T) { + cfg := &Config{ + Host: "outlook.office365.com", + Port: 993, + TLS: true, + Username: "user@company.com", + AuthMethod: AuthXOAuth2, + } + called := false + ts := func(ctx context.Context) (string, error) { + called = true + return "test-token", nil + } + c := NewClient(cfg, "", WithTokenSource(ts)) + if c.tokenSource == nil { + t.Fatal("tokenSource should be set") + } + // Verify the token source is callable + token, err := c.tokenSource(context.Background()) + if err != nil { + t.Fatal(err) + } + if token != "test-token" { + t.Errorf("token = %q, want %q", token, "test-token") + } + if !called { + t.Error("token source was not called") + } +} diff --git a/internal/imap/config.go b/internal/imap/config.go index 9b9e31c4..f30ebc90 100644 --- a/internal/imap/config.go +++ b/internal/imap/config.go @@ -8,13 +8,33 @@ import ( "strconv" ) +// AuthMethod specifies how the IMAP client authenticates. +type AuthMethod string + +const ( + // AuthPassword uses traditional LOGIN (username + password). + AuthPassword AuthMethod = "password" + // AuthXOAuth2 uses XOAUTH2 SASL mechanism (OAuth2 bearer token). + AuthXOAuth2 AuthMethod = "xoauth2" +) + // Config holds connection settings for an IMAP server. type Config struct { - Host string `json:"host"` - Port int `json:"port"` - TLS bool `json:"tls"` // Implicit TLS (IMAPS, port 993) - STARTTLS bool `json:"starttls"` // STARTTLS upgrade (port 143) - Username string `json:"username"` + Host string `json:"host"` + Port int `json:"port"` + TLS bool `json:"tls"` // Implicit TLS (IMAPS, port 993) + STARTTLS bool `json:"starttls"` // STARTTLS upgrade (port 143) + Username string `json:"username"` + AuthMethod AuthMethod `json:"auth_method,omitempty"` +} + +// EffectiveAuthMethod returns the auth method, defaulting to password +// when the field is empty (backward compatibility with existing configs). +func (c *Config) EffectiveAuthMethod() AuthMethod { + if c.AuthMethod == "" { + return AuthPassword + } + return c.AuthMethod } // Addr returns the "host:port" string. diff --git a/internal/imap/config_test.go b/internal/imap/config_test.go index dce8a88a..532bc959 100644 --- a/internal/imap/config_test.go +++ b/internal/imap/config_test.go @@ -111,3 +111,30 @@ func TestParseIdentifier_InvalidScheme(t *testing.T) { t.Error("expected error for unsupported scheme") } } + +func TestConfigAuthMethod_DefaultsToPassword(t *testing.T) { + // Existing JSON without auth_method should default to password + cfg, err := ConfigFromJSON(`{"host":"imap.example.com","port":993,"tls":true,"username":"user"}`) + if err != nil { + t.Fatal(err) + } + if cfg.AuthMethod != "" && cfg.AuthMethod != AuthPassword { + t.Errorf("AuthMethod = %q, want empty or %q", cfg.AuthMethod, AuthPassword) + } + if cfg.EffectiveAuthMethod() != AuthPassword { + t.Errorf("EffectiveAuthMethod() = %q, want %q", cfg.EffectiveAuthMethod(), AuthPassword) + } +} + +func TestConfigAuthMethod_XOAuth2(t *testing.T) { + cfg, err := ConfigFromJSON(`{"host":"outlook.office365.com","port":993,"tls":true,"username":"user@company.com","auth_method":"xoauth2"}`) + if err != nil { + t.Fatal(err) + } + if cfg.AuthMethod != AuthXOAuth2 { + t.Errorf("AuthMethod = %q, want %q", cfg.AuthMethod, AuthXOAuth2) + } + if cfg.EffectiveAuthMethod() != AuthXOAuth2 { + t.Errorf("EffectiveAuthMethod() = %q, want %q", cfg.EffectiveAuthMethod(), AuthXOAuth2) + } +} diff --git a/internal/imap/xoauth2.go b/internal/imap/xoauth2.go new file mode 100644 index 00000000..9cdc8f82 --- /dev/null +++ b/internal/imap/xoauth2.go @@ -0,0 +1,34 @@ +package imap + +import ( + "fmt" + + "github.com/emersion/go-sasl" +) + +// xoauth2Client implements sasl.Client for the XOAUTH2 mechanism +// used by Microsoft Exchange Online and Gmail IMAP. +// +// The initial response format is: +// +// "user=" + username + "\x01" + "auth=Bearer " + token + "\x01\x01" +// +// See https://developers.google.com/gmail/imap/xoauth2-protocol +type xoauth2Client struct { + username string + token string +} + +// NewXOAuth2Client creates a SASL client for XOAUTH2 authentication. +func NewXOAuth2Client(username, token string) sasl.Client { + return &xoauth2Client{username: username, token: token} +} + +func (c *xoauth2Client) Start() (mech string, ir []byte, err error) { + resp := "user=" + c.username + "\x01auth=Bearer " + c.token + "\x01\x01" + return "XOAUTH2", []byte(resp), nil +} + +func (c *xoauth2Client) Next(challenge []byte) ([]byte, error) { + return nil, fmt.Errorf("XOAUTH2: unexpected server challenge") +} diff --git a/internal/imap/xoauth2_test.go b/internal/imap/xoauth2_test.go new file mode 100644 index 00000000..723abcd4 --- /dev/null +++ b/internal/imap/xoauth2_test.go @@ -0,0 +1,51 @@ +package imap + +import "testing" + +func TestXOAuth2Client_Start(t *testing.T) { + tests := []struct { + name string + username string + token string + wantMech string + wantIR string + }{ + { + name: "basic", + username: "user@example.com", + token: "ya29.access-token", + wantMech: "XOAUTH2", + wantIR: "user=user@example.com\x01auth=Bearer ya29.access-token\x01\x01", + }, + { + name: "empty token", + username: "user@example.com", + token: "", + wantMech: "XOAUTH2", + wantIR: "user=user@example.com\x01auth=Bearer \x01\x01", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewXOAuth2Client(tt.username, tt.token) + mech, ir, err := c.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + if mech != tt.wantMech { + t.Errorf("mech = %q, want %q", mech, tt.wantMech) + } + if string(ir) != tt.wantIR { + t.Errorf("ir = %q, want %q", string(ir), tt.wantIR) + } + }) + } +} + +func TestXOAuth2Client_Next(t *testing.T) { + c := NewXOAuth2Client("user@example.com", "token") + _, err := c.Next([]byte("some challenge")) + if err == nil { + t.Fatal("Next() should return error (XOAUTH2 is single-step)") + } +} diff --git a/internal/microsoft/oauth.go b/internal/microsoft/oauth.go new file mode 100644 index 00000000..b83531fd --- /dev/null +++ b/internal/microsoft/oauth.go @@ -0,0 +1,360 @@ +package microsoft + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/wesm/msgvault/internal/fileutil" + "golang.org/x/oauth2" +) + +const ( + DefaultTenant = "common" + ScopeIMAP = "https://outlook.office365.com/IMAP.AccessAsUser.All" + + redirectPort = "8089" + callbackPath = "/callback/microsoft" + graphMeEndpoint = "https://graph.microsoft.com/v1.0/me" +) + +var Scopes = []string{ + ScopeIMAP, + "offline_access", + "openid", + "email", + "User.Read", // required for MS Graph /me to validate email +} + +type TokenMismatchError struct { + Expected string + Actual string +} + +func (e *TokenMismatchError) Error() string { + return fmt.Sprintf("token mismatch: expected %s but authorized as %s", e.Expected, e.Actual) +} + +type Manager struct { + clientID string + tenantID string + tokensDir string + logger *slog.Logger + graphURL string // override for testing + + browserFlowFn func(ctx context.Context, email string) (*oauth2.Token, error) +} + +func NewManager(clientID, tenantID, tokensDir string, logger *slog.Logger) *Manager { + if tenantID == "" { + tenantID = DefaultTenant + } + if logger == nil { + logger = slog.Default() + } + return &Manager{ + clientID: clientID, + tenantID: tenantID, + tokensDir: tokensDir, + logger: logger, + } +} + +func (m *Manager) oauthConfig() *oauth2.Config { + return &oauth2.Config{ + ClientID: m.clientID, + Endpoint: oauth2.Endpoint{ + AuthURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", m.tenantID), + TokenURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", m.tenantID), + }, + RedirectURL: "http://localhost:" + redirectPort + callbackPath, + Scopes: Scopes, + } +} + +func (m *Manager) Authorize(ctx context.Context, email string) error { + flow := m.browserFlow + if m.browserFlowFn != nil { + flow = m.browserFlowFn + } + token, err := flow(ctx, email) + if err != nil { + return err + } + if _, err := m.resolveTokenEmail(ctx, email, token); err != nil { + return err + } + return m.saveToken(email, token, Scopes) +} + +// TokenSource returns a function that provides fresh access tokens. +// Suitable for passing to imap.WithTokenSource. +func (m *Manager) TokenSource(ctx context.Context, email string) (func(context.Context) (string, error), error) { + tf, err := m.loadTokenFile(email) + if err != nil { + return nil, fmt.Errorf("no valid token for %s: %w", email, err) + } + + cfg := m.oauthConfig() + ts := cfg.TokenSource(ctx, &tf.Token) + + return func(callCtx context.Context) (string, error) { + tok, err := ts.Token() + if err != nil { + return "", fmt.Errorf("refresh Microsoft token: %w", err) + } + if tok.AccessToken != tf.Token.AccessToken { + if saveErr := m.saveToken(email, tok, tf.Scopes); saveErr != nil { + m.logger.Warn("failed to save refreshed token", "email", email, "error", saveErr) + } + tf.Token = *tok + } + return tok.AccessToken, nil + }, nil +} + +func (m *Manager) browserFlow(ctx context.Context, email string) (*oauth2.Token, error) { + cfg := m.oauthConfig() + + // PKCE (required by Azure AD for public clients) + verifierBytes := make([]byte, 32) + if _, err := rand.Read(verifierBytes); err != nil { + return nil, fmt.Errorf("generate PKCE verifier: %w", err) + } + verifier := base64.RawURLEncoding.EncodeToString(verifierBytes) + challengeHash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(challengeHash[:]) + + // CSRF state + stateBytes := make([]byte, 16) + if _, err := rand.Read(stateBytes); err != nil { + return nil, fmt.Errorf("generate state: %w", err) + } + state := base64.URLEncoding.EncodeToString(stateBytes) + + codeChan := make(chan string, 1) + errChan := make(chan error, 1) + + mux := http.NewServeMux() + mux.HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + errChan <- fmt.Errorf("state mismatch: possible CSRF attack") + fmt.Fprintf(w, "Error: state mismatch") + return + } + if errMsg := r.URL.Query().Get("error"); errMsg != "" { + desc := r.URL.Query().Get("error_description") + errChan <- fmt.Errorf("Microsoft OAuth error: %s: %s", errMsg, desc) + fmt.Fprintf(w, "Error: %s", desc) + return + } + code := r.URL.Query().Get("code") + if code == "" { + errChan <- fmt.Errorf("no code in callback") + fmt.Fprintf(w, "Error: no authorization code received") + return + } + codeChan <- code + fmt.Fprintf(w, "Authorization successful! You can close this window.") + }) + + server := &http.Server{Addr: "localhost:" + redirectPort, Handler: mux} + go func() { + if err := server.ListenAndServe(); err != http.ErrServerClosed { + errChan <- err + } + }() + defer func() { _ = server.Shutdown(ctx) }() + + authURL := cfg.AuthCodeURL(state, + oauth2.SetAuthURLParam("code_challenge", challenge), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + oauth2.SetAuthURLParam("login_hint", email), + ) + + fmt.Printf("Opening browser for Microsoft authorization...\n") + fmt.Printf("If browser doesn't open, visit:\n%s\n\n", authURL) + if err := openBrowser(authURL); err != nil { + m.logger.Warn("failed to open browser", "error", err) + } + + select { + case code := <-codeChan: + return cfg.Exchange(ctx, code, + oauth2.SetAuthURLParam("code_verifier", verifier), + ) + case err := <-errChan: + return nil, err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +const resolveTimeout = 10 * time.Second + +func (m *Manager) resolveTokenEmail(ctx context.Context, email string, token *oauth2.Token) (string, error) { + valCtx, cancel := context.WithTimeout(ctx, resolveTimeout) + defer cancel() + + cfg := m.oauthConfig() + ts := cfg.TokenSource(valCtx, token) + client := oauth2.NewClient(valCtx, ts) + + graphURL := m.graphURL + if graphURL == "" { + graphURL = graphMeEndpoint + } + req, err := http.NewRequestWithContext(valCtx, "GET", graphURL, nil) + if err != nil { + return "", fmt.Errorf("create graph request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("verify Microsoft account: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("MS Graph returned HTTP %d: %s", resp.StatusCode, string(body)) + } + + var profile struct { + Mail string `json:"mail"` + UserPrincipalName string `json:"userPrincipalName"` + } + if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil { + return "", fmt.Errorf("parse MS Graph profile: %w", err) + } + + actual := profile.Mail + if actual == "" { + actual = profile.UserPrincipalName + } + if !strings.EqualFold(actual, email) { + return "", &TokenMismatchError{Expected: email, Actual: actual} + } + + return actual, nil +} + +// --- Token storage --- + +type tokenFile struct { + oauth2.Token + Scopes []string `json:"scopes,omitempty"` +} + +func (m *Manager) TokenPath(email string) string { + safe := sanitizeEmail(email) + return filepath.Join(m.tokensDir, "microsoft_"+safe+".json") +} + +func (m *Manager) saveToken(email string, token *oauth2.Token, scopes []string) error { + if err := fileutil.SecureMkdirAll(m.tokensDir, 0700); err != nil { + return err + } + + tf := tokenFile{Token: *token, Scopes: scopes} + data, err := json.MarshalIndent(tf, "", " ") + if err != nil { + return err + } + + path := m.TokenPath(email) + tmpFile, err := os.CreateTemp(m.tokensDir, ".ms-token-*.tmp") + if err != nil { + return fmt.Errorf("create temp token file: %w", err) + } + tmpPath := tmpFile.Name() + + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + return fmt.Errorf("write temp token file: %w", err) + } + if err := tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("close temp token file: %w", err) + } + if err := fileutil.SecureChmod(tmpPath, 0600); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("chmod temp token file: %w", err) + } + if err := os.Rename(tmpPath, path); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("rename temp token file: %w", err) + } + return nil +} + +func (m *Manager) loadTokenFile(email string) (*tokenFile, error) { + path := m.TokenPath(email) + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var tf tokenFile + if err := json.Unmarshal(data, &tf); err != nil { + return nil, err + } + return &tf, nil +} + +func (m *Manager) HasToken(email string) bool { + _, err := os.Stat(m.TokenPath(email)) + return err == nil +} + +func (m *Manager) DeleteToken(email string) error { + err := os.Remove(m.TokenPath(email)) + if os.IsNotExist(err) { + return nil + } + return err +} + +func sanitizeEmail(email string) string { + safe := strings.ReplaceAll(email, "/", "_") + safe = strings.ReplaceAll(safe, "\\", "_") + safe = strings.ReplaceAll(safe, "..", "_..") + return safe +} + +func openBrowser(rawURL string) error { + parsed, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + scheme := strings.ToLower(parsed.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("refused to open URL with scheme %q", parsed.Scheme) + } + + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", rawURL) + case "linux": + cmd = exec.Command("xdg-open", rawURL) + case "windows": + cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", rawURL) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } + return cmd.Start() +} diff --git a/internal/microsoft/oauth_test.go b/internal/microsoft/oauth_test.go new file mode 100644 index 00000000..7a506dfd --- /dev/null +++ b/internal/microsoft/oauth_test.go @@ -0,0 +1,195 @@ +package microsoft + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "golang.org/x/oauth2" +) + +func TestTokenPath(t *testing.T) { + m := &Manager{tokensDir: "/tmp/tokens"} + path := m.TokenPath("user@example.com") + want := "/tmp/tokens/microsoft_user@example.com.json" + if path != want { + t.Errorf("TokenPath = %q, want %q", path, want) + } +} + +func TestSaveAndLoadToken(t *testing.T) { + dir := t.TempDir() + m := &Manager{tokensDir: dir} + token := &oauth2.Token{ + AccessToken: "access-123", + RefreshToken: "refresh-456", + TokenType: "Bearer", + } + scopes := []string{"IMAP.AccessAsUser.All", "offline_access"} + + if err := m.saveToken("user@example.com", token, scopes); err != nil { + t.Fatal(err) + } + + loaded, err := m.loadTokenFile("user@example.com") + if err != nil { + t.Fatal(err) + } + if loaded.AccessToken != "access-123" { + t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, "access-123") + } + if loaded.RefreshToken != "refresh-456" { + t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, "refresh-456") + } + if len(loaded.Scopes) != 2 { + t.Errorf("Scopes len = %d, want 2", len(loaded.Scopes)) + } + + // Verify file permissions + path := m.TokenPath("user@example.com") + info, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("permissions = %o, want 0600", info.Mode().Perm()) + } +} + +func TestHasToken(t *testing.T) { + dir := t.TempDir() + m := &Manager{tokensDir: dir} + + if m.HasToken("nobody@example.com") { + t.Error("HasToken should be false for non-existent token") + } + + token := &oauth2.Token{AccessToken: "test"} + if err := m.saveToken("user@example.com", token, nil); err != nil { + t.Fatal(err) + } + if !m.HasToken("user@example.com") { + t.Error("HasToken should be true after save") + } +} + +func TestDeleteToken(t *testing.T) { + dir := t.TempDir() + m := &Manager{tokensDir: dir} + + token := &oauth2.Token{AccessToken: "test"} + if err := m.saveToken("user@example.com", token, nil); err != nil { + t.Fatal(err) + } + if err := m.DeleteToken("user@example.com"); err != nil { + t.Fatal(err) + } + if m.HasToken("user@example.com") { + t.Error("HasToken should be false after delete") + } + // Delete non-existent should not error + if err := m.DeleteToken("nobody@example.com"); err != nil { + t.Errorf("DeleteToken non-existent: %v", err) + } +} + +func TestSanitizeEmail(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"user@example.com", "user@example.com"}, + {"../evil", "_.._evil"}, + {"a/b", "a_b"}, + {"a\\b", "a_b"}, + } + for _, tt := range tests { + got := sanitizeEmail(tt.input) + if got != tt.want { + t.Errorf("sanitizeEmail(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestResolveTokenEmail_Match(t *testing.T) { + // Mock MS Graph /me endpoint + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "mail": "user@example.com", + "userPrincipalName": "user@example.com", + }) + })) + defer server.Close() + + m := &Manager{ + clientID: "test-client", + tenantID: "common", + tokensDir: t.TempDir(), + graphURL: server.URL, + } + + token := &oauth2.Token{AccessToken: "test-token", TokenType: "Bearer"} + actual, err := m.resolveTokenEmail(t.Context(), "user@example.com", token) + if err != nil { + t.Fatal(err) + } + if actual != "user@example.com" { + t.Errorf("actual = %q, want %q", actual, "user@example.com") + } +} + +func TestResolveTokenEmail_Mismatch(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "mail": "other@example.com", + "userPrincipalName": "other@example.com", + }) + })) + defer server.Close() + + m := &Manager{ + clientID: "test-client", + tenantID: "common", + tokensDir: t.TempDir(), + graphURL: server.URL, + } + + token := &oauth2.Token{AccessToken: "test-token", TokenType: "Bearer"} + _, err := m.resolveTokenEmail(t.Context(), "user@example.com", token) + if err == nil { + t.Fatal("expected error for mismatch") + } + _, ok := err.(*TokenMismatchError) + if !ok { + t.Errorf("expected *TokenMismatchError, got %T: %v", err, err) + } +} + +func TestResolveTokenEmail_FallbackToUPN(t *testing.T) { + // Some accounts have empty mail, only userPrincipalName + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "mail": "", + "userPrincipalName": "user@example.com", + }) + })) + defer server.Close() + + m := &Manager{ + clientID: "test-client", + tenantID: "common", + tokensDir: t.TempDir(), + graphURL: server.URL, + } + + token := &oauth2.Token{AccessToken: "test-token", TokenType: "Bearer"} + actual, err := m.resolveTokenEmail(t.Context(), "user@example.com", token) + if err != nil { + t.Fatal(err) + } + if actual != "user@example.com" { + t.Errorf("actual = %q, want %q", actual, "user@example.com") + } +}