Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: simplify token reuse algorithm #1072

Merged
merged 3 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 18 additions & 38 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,39 +311,16 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
}
}

var latestValidToken *models.RefreshToken
if token.Revoked {
a.clearCookieTokens(config, w)
/**
For a revoked refresh token to be reused, it has to fulfil 2 conditions:
1. The revoked refresh token has to be the parent of the most recently issued valid refresh token.
2. The revoked refresh token has to fall within the reuse interval.
**/
err = db.Transaction(func(tx *storage.Connection) error {
validChildToken, terr := models.GetValidChildToken(tx, token)
if terr != nil {
if errors.Is(terr, models.RefreshTokenNotFoundError{}) {
// revoked token has no valid child token
return nil
}
return terr
}
// check if the parent of the child token is the revoked token being used
// validChildToken.Parent and token.Token will never be null here
if validChildToken.Parent == storage.NullString(token.Token) {
refreshTokenReuseWindow := token.UpdatedAt.Add(time.Second * time.Duration(config.Security.RefreshTokenReuseInterval))
// check if the revoked token falls within the reuse interval
if time.Now().Before(refreshTokenReuseWindow) {
latestValidToken = validChildToken
}
}
return nil
})
if err != nil {
return internalServerError("Error validating reuse interval").WithInternalError(err)
}
// For a revoked refresh token to be reused, it has to fall within the reuse interval.

reuseUntil := token.UpdatedAt.Add(
time.Second * time.Duration(config.Security.RefreshTokenReuseInterval))

if time.Now().After(reuseUntil) {
// not OK to reuse this token

if latestValidToken == nil {
if config.Security.RefreshTokenRotationEnabled {
// Revoke all tokens in token family
err = db.Transaction(func(tx *storage.Connection) error {
Expand All @@ -357,7 +334,8 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return internalServerError(err.Error())
}
}
return oauthError("invalid_grant", "Invalid Refresh Token").WithInternalMessage("Possible abuse attempt: %v", r)

return oauthError("invalid_grant", "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID)
}
}

Expand All @@ -370,13 +348,15 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return terr
}

if latestValidToken == nil {
latestValidToken, terr = models.GrantRefreshTokenSwap(r, tx, user, token)
if terr != nil {
return internalServerError(terr.Error())
}
// a new refresh token is generated and explicitly not reusing
// a previous one as it could have already been revoked while
// this handler was running
newToken, terr := models.GrantRefreshTokenSwap(r, tx, user, token)
if terr != nil {
return terr
}
tokenString, terr = generateAccessToken(tx, user, latestValidToken.SessionId, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret)

tokenString, terr = generateAccessToken(tx, user, newToken.SessionId, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret)

if terr != nil {
return internalServerError("error generating jwt token").WithInternalError(terr)
Expand All @@ -386,7 +366,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
Token: tokenString,
TokenType: "bearer",
ExpiresIn: config.JWT.Exp,
RefreshToken: latestValidToken.Token,
RefreshToken: newToken.Token,
User: user,
}
if terr = a.setCookieTokens(config, newTokenResponse, false, w); terr != nil {
Expand Down
29 changes: 10 additions & 19 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
require.NoError(ts.T(), err)
second, err := models.GrantRefreshTokenSwap(&http.Request{}, ts.API.db, u, first)
require.NoError(ts.T(), err)
third, err := models.GrantRefreshTokenSwap(&http.Request{}, ts.API.db, u, second)
require.NoError(ts.T(), err)

cases := []struct {
desc string
Expand All @@ -232,40 +230,29 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
refreshToken: second.Token,
expectedCode: http.StatusOK,
expectedBody: map[string]interface{}{
"refresh_token": third.Token,
"refresh_token": "some-new-refresh-token",
},
},
{
desc: "Invalid refresh, first token is not the previous revoked token",
desc: "Invalid refresh outside reuse interval",
refreshTokenRotationEnabled: true,
reuseInterval: 0,
refreshToken: first.Token,
expectedCode: http.StatusBadRequest,
expectedBody: map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token",
"error_description": "Invalid Refresh Token: Already Used",
},
},
{
desc: "Invalid refresh, revoked third token",
desc: "Invalid refresh, revoke third token",
refreshTokenRotationEnabled: true,
reuseInterval: 0,
refreshToken: second.Token,
expectedCode: http.StatusBadRequest,
expectedBody: map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token",
},
},
{
desc: "Invalid refresh, third token revoked by previous case",
refreshTokenRotationEnabled: true,
reuseInterval: 30,
refreshToken: third.Token,
expectedCode: http.StatusBadRequest,
expectedBody: map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token",
"error_description": "Invalid Refresh Token: Already Used",
},
},
}
Expand All @@ -287,7 +274,11 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
data := make(map[string]interface{})
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
for k, v := range c.expectedBody {
require.Equal(ts.T(), v, data[k])
if k == "refresh_token" {
require.NotEmpty(ts.T(), v, data[k])
} else {
require.Equal(ts.T(), v, data[k])
}
}
})
}
Expand Down