ブラウザから直接API GatewayのエンドポイントにアクセスしたときにCognitoのTokenで認証し、失敗したらログイン画面を表示させる。 API GatewayでCognitoの認証をかける場合、AuthorizerでUserPoolを指定するのが最も簡単なパターンだが、 これだとHeaderにTokenを付けてアクセスする必要があり認証に失敗するとUnauthorizedが返る。
Cognito UserPoolとAPI Gatewayで認証付きAPIを立てる - sambaiz-net
なおAPI GatewayではなくALBをLambdaの前段に挟めば今回やることが簡単に実現できる。
LambdaとALBでCognito認証をかけて失敗したらログイン画面に飛ばす - sambaiz-net
準備
UserPoolとClientを作成する。 CloudFormationで作成する場合SchemaのMutableのデフォルトがfalseで、変えると作り直されてしまうことに注意。
Resources:
Userpool:
Type: AWS::Cognito::UserPool
Properties:
AdminCreateUserConfig:
AllowAdminCreateUserOnly: false
Schema:
- Mutable: true
Name: email
Required: true
- Mutable: true
Name: name
Required: true
UsernameAttributes:
- email
UserPoolName: testpool
UserpoolClient:
Type: AWS::Cognito::UserPoolClient
Properties:
UserPoolId:
Ref: Userpool
ClientName: testclient
GenerateSecret: true
その後、GoogleのOAuth Client IDを作成し、フェデレーションの設定を行ってGoogleアカウントでもログインできるようにした。 これはID Poolのフェデレーティッドアイデンティティとは異なる機能。 UserPoolのドメインや、外部IdPとのAttributes Mapping、Clientの設定はCloudFormationではできないので手で行う。
追記 (2020-12-06): 今はCloudFormationで行えるようになっている。
CDKでCognito UserPoolとClientを作成しトリガーやFederationを設定する - sambaiz-net
Attributes Mappingにusernameの項目があるが、変更不可な値なのでsubのままにしておく。
設定が終わったら https://<user-pool-domain>/login?client_id=<client-id>&redirect_uri=<redirect-uri>&response_type=code
にアクセスし、<redirect-uri>?code=<code>
までリダイレクトされることを確認する。
invalid_requestとだけ出てしまった場合はClientのAuthorization code grantにチェックが入っているか確認する。
実装
Authorization Flowを行う。
Cookieで渡ってきたTokenを検証し、失敗した場合はログイン画面にリダイレクトさせる。 identity_providerを指定するかClientに一つしか紐づいていなければそのIdPのログイン画面に飛び、そうでなければ選択する画面が表示される。 ログインしてクエリパラメータとして渡ってきたcodeとstateもこの関数で受け取り、検証してTokenを発行し一旦返してCookieに焼く。
OpenID ConnectのIDトークンの内容と検証 - sambaiz-net
通常stateはサーバーに保持されるがDBを持たなくていいようにCookieに焼いている。
CSRF対策のためのものなのでCookieが攻撃者に読み書きされなければ意味を為すと思っているが、このフローで問題がある場合は教えてほしい。
直接アクセスすることを想定してTokenをCookieに焼いているので、CORSの、特にAccess-Control-Allow-Origin: *
のような設定を行うと穴が開いてしまう。
package auth
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/aws/aws-lambda-go/events"
jwt "github.com/dgrijalva/jwt-go"
"github.com/lestrrat/go-jwx/jwk"
)
const (
stateCookieName = "st"
tokenCookieName = "ck"
)
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IdToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Error *string `json:"error"`
}
func AuthOrLogin(
request events.APIGatewayProxyRequest,
userPoolClientID string,
userPoolClientSecret string,
userPoolURL string,
userPoolID string,
region string) *events.APIGatewayProxyResponse {
myselfURL := fmt.Sprintf("https://%s/%s%s", request.Headers["Host"], request.RequestContext.Stage, request.Path)
if code := getCodeFromParams(request.QueryStringParameters); code != nil {
if err := checkState(request); err != nil {
fmt.Printf(err.Error())
return newErrorResponse()
}
tokenResp, err := requestTokenByCode(*code, userPoolClientID, userPoolClientSecret, myselfURL, userPoolURL)
if err != nil {
fmt.Println(err.Error())
return newErrorResponse()
}
resp := newRedirectResponse(myselfURL)
resp.MultiValueHeaders = map[string][]string{
"Set-Cookie": []string{
fmt.Sprintf("%s=%s; Secure; HttpOnly", tokenCookieName, tokenResp.AccessToken),
// delete state
fmt.Sprintf("%s=; Expires=Thu, 1-Jan-1990 00:00:00 GMT; Secure; HttpOnly", stateCookieName),
},
}
return resp
}
if token := getTokenFromCookie(request.Headers); token != nil {
claims, err := parseToken(*token, userPoolClientID, userPoolID, region)
if err == nil {
fmt.Println(claims)
return nil // ok
} else {
fmt.Println(err.Error()) // re-login
}
}
loginURL, state, err := makeLoginURL(userPoolURL, userPoolClientID, myselfURL)
if err != nil {
fmt.Println(err.Error())
return newErrorResponse()
}
resp := newRedirectResponse(loginURL)
resp.MultiValueHeaders = map[string][]string{
"Set-Cookie": []string{
// delete old token
fmt.Sprintf("%s=; Expires=Thu, 1-Jan-1990 00:00:00 GMT; Secure; HttpOnly", tokenCookieName),
fmt.Sprintf("%s=%s; Secure; HttpOnly", stateCookieName, state),
},
}
return resp
}
func getCodeFromParams(params map[string]string) *string {
if code, ok := params["code"]; ok {
return &code
}
return nil
}
func getStateFromCookie(headers map[string]string) *string {
if cookie, ok := headers["Cookie"]; ok {
header := http.Header{}
header.Add("Cookie", cookie)
request := http.Request{Header: header}
c, err := request.Cookie(stateCookieName)
if err != nil {
return nil
}
return &c.Value
}
return nil
}
func getTokenFromCookie(headers map[string]string) *string {
if cookie, ok := headers["Cookie"]; ok {
header := http.Header{}
header.Add("Cookie", cookie)
request := http.Request{Header: header}
c, err := request.Cookie(tokenCookieName)
if err != nil {
return nil
}
return &c.Value
}
return nil
}
func newRedirectResponse(location string) *events.APIGatewayProxyResponse {
return &events.APIGatewayProxyResponse{
StatusCode: http.StatusFound,
Headers: map[string]string{
"Location": location,
},
}
}
func newErrorResponse() *events.APIGatewayProxyResponse {
return &events.APIGatewayProxyResponse{
StatusCode: http.StatusInternalServerError,
}
}
func makeState() (string, error) {
uuid, err := uuid.NewV4()
if err != nil {
return "", err
}
return uuid.String(), nil
}
func checkState(request events.APIGatewayProxyRequest) error {
if state, ok := request.QueryStringParameters["state"]; ok {
if hasState := getStateFromCookie(request.Headers); hasState != nil {
if state == *hasState {
return nil
}
}
}
return errors.New("state is invalid")
}
func requestTokenByCode(code string, userPoolClientID string, userPoolClientSecret string, myselfURL string, userPoolURL string) (*TokenResponse, error) {
// https://docs.aws.amazon.com/ja_jp/cognito/latest/developerguide/token-endpoint.html
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("code", code)
form.Add("redirect_uri", myselfURL)
body := strings.NewReader(form.Encode())
req, err := http.NewRequest("POST", fmt.Sprintf("%s/oauth2/token", userPoolURL), body)
if err != nil {
return nil, err
}
req.Header.Add("Authorization",
fmt.Sprintf("Basic %s",
base64.StdEncoding.EncodeToString(
[]byte(fmt.Sprintf("%s:%s", userPoolClientID, userPoolClientSecret)))))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
var tokenResp TokenResponse
bytes, err := ioutil.ReadAll(res.Body)
if err := json.Unmarshal(bytes, &tokenResp); err != nil {
return nil, err
}
if tokenResp.Error != nil {
return nil, errors.New(*tokenResp.Error)
}
return &tokenResp, nil
}
func makeLoginURL(userPoolURL string, userPoolClientID string, myselfURL string) (string, string, error) {
// https://docs.aws.amazon.com/ja_jp/cognito/latest/developerguide/authorization-endpoint.html
u, err := url.Parse(fmt.Sprintf("%s/oauth2/authorize", userPoolURL))
if err != nil {
return "", "", errors.New("failed to parse")
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", userPoolClientID)
q.Set("redirect_uri", myselfURL)
q.Set("identity_provider", "Google")
state, err := makeState()
if err != nil {
return "", "", errors.New("failed to make state")
}
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String(), state, nil
}
func parseToken(tokenString string, userPoolClientID string, userPoolID string, region string) (jwt.MapClaims, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// RS256
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return getKey(token, userPoolID, region)
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// https://docs.aws.amazon.com/ja_jp/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
if v, ok := claims["exp"].(float64); ok {
if int64(v) < time.Now().Unix() {
return nil, errors.New("token is expired")
}
} else {
return nil, errors.New("failed to get claims[exp]")
}
// instead of aud
if v, ok := claims["client_id"].(string); ok {
if v != userPoolClientID {
return nil, errors.New("token has invalid audience")
}
} else {
return nil, errors.New("failed to get claims[client_id]")
}
if v, ok := claims["iss"].(string); ok {
if v != fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s", region, userPoolID) {
return nil, errors.New("token has invalid issuer")
}
} else {
return nil, errors.New("failed to get claims[iss]")
}
return claims, nil
}
return nil, errors.New("token is invalid")
}
func getKey(token *jwt.Token, userPoolID string, region string) (interface{}, error) {
set, err := jwk.FetchHTTP(
fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", region, userPoolID),
)
if err != nil {
return nil, err
}
keyID, ok := token.Header["kid"].(string)
if !ok {
return nil, errors.New("expecting JWT header to have string kid")
}
if key := set.LookupKeyID(keyID); len(key) == 1 {
return key[0].Materialize()
}
return nil, errors.New("unable to find key")
}