Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const (

//go:generate mockgen -source=auth.go -destination=test/auth_mocks.go -package=test Client
type Client interface {
ProviderSessionAccessor
ProviderSessionClient
RestrictedTokenAccessor
ExternalAccessor
permission.Client
Expand Down
21 changes: 10 additions & 11 deletions auth/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,17 @@ func NewClient(cfg *Config, authorizeAs platform.AuthorizeAs, name string, lgr l
}, nil
}

func (c *Client) CreateUserProviderSession(ctx context.Context, userID string, create *auth.ProviderSessionCreate) (*auth.ProviderSession, error) {
func (c *Client) CreateProviderSession(ctx context.Context, create *auth.ProviderSessionCreate) (*auth.ProviderSession, error) {
if ctx == nil {
return nil, errors.New("context is missing")
}
if userID == "" {
return nil, errors.New("user id is missing")
}
if create == nil {
return nil, errors.New("create is missing")
} else if err := structureValidator.New(log.LoggerFromContext(ctx)).Validate(create); err != nil {
return nil, errors.Wrap(err, "create is invalid")
}

url := c.client.ConstructURL("v1", "users", userID, "provider_sessions")
url := c.client.ConstructURL("v1", "provider_sessions")
providerSession := &auth.ProviderSession{}
if err := c.client.RequestData(ctx, http.MethodPost, url, nil, create, providerSession); err != nil {
return nil, err
Expand All @@ -94,24 +91,26 @@ func (c *Client) CreateUserProviderSession(ctx context.Context, userID string, c
return providerSession, nil
}

func (c *Client) DeleteUserProviderSessions(ctx context.Context, userID string) error {
func (c *Client) DeleteProviderSessions(ctx context.Context, filter *auth.ProviderSessionFilter) error {
if ctx == nil {
return errors.New("context is missing")
}
if userID == "" {
return errors.New("user id is missing")
if filter == nil {
return errors.New("filter is missing")
} else if err := structureValidator.New(log.LoggerFromContext(ctx)).Validate(filter); err != nil {
return errors.Wrap(err, "filter is invalid")
}

url := c.client.ConstructURL("v1", "users", userID, "provider_sessions")
return c.client.RequestData(ctx, http.MethodDelete, url, nil, nil, nil)
url := c.client.ConstructURL("v1", "provider_sessions")
return c.client.RequestData(ctx, http.MethodDelete, url, []request.RequestMutator{filter}, nil, nil)
}

func (c *Client) ListProviderSessions(ctx context.Context, filter *auth.ProviderSessionFilter, pagination *page.Pagination) (auth.ProviderSessions, error) {
if ctx == nil {
return nil, errors.New("context is missing")
}
if filter == nil {
filter = auth.NewProviderSessionFilter()
return nil, errors.New("filter is missing")
} else if err := structureValidator.New(log.LoggerFromContext(ctx)).Validate(filter); err != nil {
return nil, errors.Wrap(err, "filter is invalid")
}
Expand Down
2 changes: 1 addition & 1 deletion auth/events/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (u *userDeletionEventsHandler) HandleDeleteUserEvent(payload ev.DeleteUserE
}

logger.Infof("Deleting provider sessions for user")
if err := u.client.DeleteUserProviderSessions(u.ctx, payload.UserID); err != nil {
if err := u.client.DeleteProviderSessions(u.ctx, &auth.ProviderSessionFilter{UserID: &payload.UserID}); err != nil {
errs = append(errs, err)
logger.WithError(err).Error("unable to delete provider sessions for user")
}
Expand Down
40 changes: 13 additions & 27 deletions auth/provider_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/tidepool-org/platform/errors"
"github.com/tidepool-org/platform/id"
"github.com/tidepool-org/platform/log"
"github.com/tidepool-org/platform/page"
"github.com/tidepool-org/platform/request"
"github.com/tidepool-org/platform/structure"
Expand All @@ -26,11 +25,11 @@ func ProviderTypes() []string {
}
}

type ProviderSessionAccessor interface {
CreateUserProviderSession(ctx context.Context, userID string, create *ProviderSessionCreate) (*ProviderSession, error)
DeleteUserProviderSessions(ctx context.Context, userID string) error

type ProviderSessionClient interface {
ListProviderSessions(ctx context.Context, filter *ProviderSessionFilter, pagination *page.Pagination) (ProviderSessions, error)
DeleteProviderSessions(ctx context.Context, filter *ProviderSessionFilter) error

CreateProviderSession(ctx context.Context, create *ProviderSessionCreate) (*ProviderSession, error)
GetProviderSession(ctx context.Context, id string) (*ProviderSession, error)
UpdateProviderSession(ctx context.Context, id string, update *ProviderSessionUpdate) (*ProviderSession, error)
DeleteProviderSession(ctx context.Context, id string) error
Expand Down Expand Up @@ -59,6 +58,9 @@ func (p *ProviderSessionFilter) Validate(validator structure.Validator) {
validator.String("type", p.Type).OneOf(ProviderTypes()...)
validator.String("name", p.Name).Using(ProviderNameValidator)
validator.String("externalId", p.ExternalID).Using(ProviderExternalIDValidator)
if p.UserID == nil && p.ExternalID == nil {
validator.ReportError(structureValidator.ErrorValuesNotExistForAny("externalId", "userId"))
}
}

func (p *ProviderSessionFilter) MutateRequest(req *http.Request) error {
Expand All @@ -79,6 +81,7 @@ func (p *ProviderSessionFilter) MutateRequest(req *http.Request) error {
}

type ProviderSessionCreate struct {
UserID string `json:"userId" bson:"userId"`
Type string `json:"type" bson:"type"`
Name string `json:"name" bson:"name"`
OAuthToken *OAuthToken `json:"oauthToken,omitempty" bson:"oauthToken,omitempty"`
Expand All @@ -90,6 +93,9 @@ func NewProviderSessionCreate() *ProviderSessionCreate {
}

func (p *ProviderSessionCreate) Parse(parser structure.ObjectParser) {
if ptr := parser.String("userId"); ptr != nil {
p.UserID = *ptr
}
if ptr := parser.String("type"); ptr != nil {
p.Type = *ptr
}
Expand All @@ -105,6 +111,7 @@ func (p *ProviderSessionCreate) Parse(parser structure.ObjectParser) {
}

func (p *ProviderSessionCreate) Validate(validator structure.Validator) {
validator.String("userId", &p.UserID).Using(user.IDValidator)
validator.String("type", &p.Type).OneOf(ProviderTypes()...)
validator.String("name", &p.Name).Using(ProviderNameValidator)
switch p.Type {
Expand Down Expand Up @@ -223,27 +230,6 @@ type ProviderSession struct {
ModifiedTime *time.Time `json:"modifiedTime,omitempty" bson:"modifiedTime,omitempty"`
}

func NewProviderSession(ctx context.Context, userID string, create *ProviderSessionCreate) (*ProviderSession, error) {
if userID == "" {
return nil, errors.New("user id is missing")
}
if create == nil {
return nil, errors.New("create is missing")
} else if err := structureValidator.New(log.LoggerFromContext(ctx)).Validate(create); err != nil {
return nil, errors.Wrap(err, "create is invalid")
}

return &ProviderSession{
ID: NewProviderSessionID(),
UserID: userID,
Type: create.Type,
Name: create.Name,
OAuthToken: create.OAuthToken,
ExternalID: create.ExternalID,
CreatedTime: time.Now(),
}, nil
}

func (p *ProviderSession) Parse(parser structure.ObjectParser) {
if ptr := parser.String("id"); ptr != nil {
p.ID = *ptr
Expand Down Expand Up @@ -271,7 +257,7 @@ func (p *ProviderSession) Parse(parser structure.ObjectParser) {

func (p *ProviderSession) Validate(validator structure.Validator) {
validator.String("id", &p.ID).Using(ProviderSessionIDValidator)
validator.String("userId", &p.UserID).Using(UserIDValidator)
validator.String("userId", &p.UserID).Using(user.IDValidator)
validator.String("type", &p.Type).OneOf(ProviderTypes()...)
validator.String("name", &p.Name).Using(ProviderNameValidator)
switch p.Type {
Expand Down
4 changes: 2 additions & 2 deletions auth/providersession/provider_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ package providersession

import "github.com/tidepool-org/platform/auth"

//go:generate mockgen -destination=test/provider_session_mocks.go -package=test --mock_names=ProviderSessionAccessor=MockClient github.com/tidepool-org/platform/auth ProviderSessionAccessor
type Client auth.ProviderSessionAccessor
//go:generate mockgen -destination=test/provider_session_mocks.go -package=test --mock_names=ProviderSessionClient=MockClient github.com/tidepool-org/platform/auth ProviderSessionClient
type Client auth.ProviderSessionClient
30 changes: 15 additions & 15 deletions auth/providersession/test/provider_session_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 17 additions & 36 deletions auth/service/api/v1/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/tidepool-org/platform/errors"
"github.com/tidepool-org/platform/log"
"github.com/tidepool-org/platform/oauth"
"github.com/tidepool-org/platform/page"
"github.com/tidepool-org/platform/pointer"
"github.com/tidepool-org/platform/provider"
"github.com/tidepool-org/platform/request"
Expand Down Expand Up @@ -60,11 +59,14 @@ func (r *Router) OAuthProviderAuthorizeGet(res rest.ResponseWriter, req *rest.Re
if err != nil {
r.htmlOnError(res, req, err)
return
} else if restrictedToken == nil {
r.htmlOnError(res, req, request.ErrorUnauthenticated(), usedError)
return
}

maxAge := time.Until(restrictedToken.ExpirationTime) / time.Second
if maxAge <= 0 {
r.htmlOnError(res, req, request.ErrorUnauthenticated())
r.htmlOnError(res, req, request.ErrorUnauthenticated(), expiredError)
return
}

Expand Down Expand Up @@ -121,23 +123,11 @@ func (r *Router) UserOAuthProviderAuthorizeDelete(res rest.ResponseWriter, req *
providerSessionFilter.UserID = pointer.FromString(userID)
providerSessionFilter.Type = pointer.FromString(prvdr.Type())
providerSessionFilter.Name = pointer.FromString(prvdr.Name())
providerSessions, err := r.AuthClient().ListProviderSessions(ctx, providerSessionFilter, page.NewPagination())
if err != nil {
if err := r.AuthClient().DeleteProviderSessions(ctx, providerSessionFilter); err != nil {
responder.Error(http.StatusInternalServerError, err)
return
}

if len(providerSessions) > 1 {
r.Logger().WithFields(log.Fields{"userId": userID, "filter": providerSessionFilter, "providerSessions": providerSessions}).Warn("Deleting multiple provider sessions")
}

for _, providerSession := range providerSessions {
if err = r.AuthClient().DeleteProviderSession(ctx, providerSession.ID); err != nil {
responder.Error(http.StatusInternalServerError, err)
return
}
}

responder.Empty(http.StatusOK)
}

Expand Down Expand Up @@ -190,10 +180,6 @@ func (r *Router) OAuthProviderRedirectGet(res rest.ResponseWriter, req *rest.Req
responder.SetCookie(r.providerCookie(prvdr, restrictedToken.ID, -1))
}

if err = r.AuthClient().DeleteRestrictedToken(ctx, restrictedToken.ID); err != nil {
log.LoggerFromContext(ctx).WithError(err).Error("unable to delete restricted token after oauth redirect")
}

if errorCode := query.Get("error"); prvdr.IsErrorCodeAccessDenied(errorCode) {
html := fmt.Sprintf(htmlOnRedirect, redirectURLDeclined.String())
r.htmlOnRedirect(res, req, html)
Expand All @@ -207,20 +193,9 @@ func (r *Router) OAuthProviderRedirectGet(res rest.ResponseWriter, req *rest.Req
filter.UserID = pointer.FromString(restrictedToken.UserID)
filter.Type = pointer.FromString(prvdr.Type())
filter.Name = pointer.FromString(prvdr.Name())
providerSessions, err := r.AuthClient().ListProviderSessions(ctx, filter, nil)
if err != nil {
if err := r.AuthClient().DeleteProviderSessions(ctx, filter); err != nil {
r.htmlOnError(res, req, err)
return
} else if len(providerSessions) > 0 {
// Delete existing provider sessions and tasks if matching name and type found for user.
// This operation will also reset the data source to a `disconnected` state, and remove any associated tasks
// A new provider session and task will be created below which will update the existing data source state to `connected`.
for _, session := range providerSessions {
if deleteSessionErr := r.AuthClient().DeleteProviderSession(ctx, session.ID); deleteSessionErr != nil {
r.htmlOnError(res, req, errors.Newf("could not remove existing provider session"), alreadyConnectedError)
return
}
}
}

oauthToken, err := prvdr.ExchangeAuthorizationCodeForToken(ctx, query.Get("code"))
Expand All @@ -230,15 +205,19 @@ func (r *Router) OAuthProviderRedirectGet(res rest.ResponseWriter, req *rest.Req
}

providerSessionCreate := auth.NewProviderSessionCreate()
providerSessionCreate.UserID = restrictedToken.UserID
providerSessionCreate.Type = prvdr.Type()
providerSessionCreate.Name = prvdr.Name()
providerSessionCreate.OAuthToken = oauthToken
_, err = r.AuthClient().CreateUserProviderSession(ctx, restrictedToken.UserID, providerSessionCreate)
if err != nil {
if _, err = r.AuthClient().CreateProviderSession(ctx, providerSessionCreate); err != nil {
r.htmlOnError(res, req, err)
return
}

if err = r.AuthClient().DeleteRestrictedToken(ctx, restrictedToken.ID); err != nil {
log.LoggerFromContext(ctx).WithError(err).Error("unable to delete restricted token after oauth redirect")
}

html := fmt.Sprintf(htmlOnRedirect, redirectURLAuthorized.String())
r.htmlOnRedirect(res, req, html)
}
Expand Down Expand Up @@ -339,7 +318,7 @@ func (r *Router) htmlOnError(res rest.ResponseWriter, req *rest.Request, err err
log.LoggerFromContext(req.Context()).WithError(err).WithField("messages", messages).Error("Unexpected failure during OAuth workflow")
request.MustNewResponder(res, req).String(
request.StatusCodeForError(err),
strings.Replace(htmlOnError, "{{ MESSAGES }}", strings.Join(messages, " "), -1),
strings.ReplaceAll(htmlOnError, "{{ MESSAGES }}", strings.Join(messages, " ")),
request.NewHeaderMutator("Content-Type", "text/html"),
)
}
Expand Down Expand Up @@ -425,5 +404,7 @@ const htmlOnError = `
</body>
</html>
`
const unexpectedError = `Looks like an unexpected error occurred. You can try again, or send an email to support@tidepool.org for help.`
const alreadyConnectedError = `This Tidepool account has already been connected to a Dexcom account. If this doesn't sound right, please send an email to support@tidepool.org and we'll help you out.`

const unexpectedError = "Looks like an unexpected error occurred. You can try again, or send an email to support@tidepool.org for help."
const usedError = "This connection request has already been used and can only be used once. If this doesn't sound right, please send an email to support@tidepool.org and we'll help you out."
const expiredError = "This connection request has expired. If this doesn't sound right, please send an email to support@tidepool.org and we'll help you out."
Loading
Loading