API GatewayでCognitoの認証をかけて必要ならログイン画面に飛ばす処理をGoで書く

(2019-07-03)

ブラウザから直接API GatewayのエンドポイントにアクセスしたときにCognitoのTokenで認証し、失敗したらログイン画面を表示させる。 API GatewayでCognitoの認証をかける場合、AuthorizerでUserPoolを指定するのが最も簡単なパターンだが、 これだとHeaderにTokenを付けてアクセスする必要があり認証に失敗するとUnauthorizedが返る。

Cognito UserPoolとAPI Gatewayで認証付きAPIを立てる - sambaiz-net

なおAPI GatewayではなくALBをLambdaの前段に挟めば今回やることが簡単に実現できる。

LambdaとALBでCognito認証をかけて失敗したらログイン画面に飛ばす - sambaiz-net

準備

UserPoolとClientを作成する。 CloudFormationで作成する場合SchemaのMutableのデフォルトがfalseなのに注意。変えると作り直される。

Resources:
  Userpool:
    Type: AWS::Cognito::UserPool
    Properties:
      AdminCreateUserConfig:
        AllowAdminCreateUserOnly: false
      Schema:
        - Mutable: true
          Name: email
          Required: true
        - Mutable: true
          Name: name
          Required: true
      UsernameAttributes:
        - email
      UserPoolName: testpool
  UserpoolClient:
    Type: AWS::Cognito::UserPoolClient
    Properties:
      UserPoolId:
        Ref: Userpool
      ClientName: testclient
      GenerateSecret: true

今回はフェデレーションでGoogleアカウントでログインできるようにする。 ややこしい用語であるがID Poolのフェデレーティッドアイデンティティとは異なる。

UserPoolのドメインや、GoogleのOAuth Client IDの発行とAttributes Mapping、Clientの設定はCloudFormationでできないので手でやる。 Attributes Mappingで変えられるがusernameは変更不可な値なのでsubのままにしておく。

設定が終わったら https://<user-pool-domain>/login?client_id=<client-id>&redirect_uri=<redirect-uri>&response_type=code にアクセスし、<redirect-uri>?code=<code> までリダイレクトされることを確認する。 invalid_requestとだけ出てしまった場合はClientのAuthorization code grantにチェックが入っているか確認する。

実装

Authorization Flowを行う。

OAuth2.0のメモ - sambaiz-net

Cookieで渡ってきたTokenを検証し、失敗した場合はログイン画面にリダイレクトさせる。 identity_providerを指定するかClientに一つしか紐づいていなければそのプロバイダのログイン画面に飛び、そうでなければ選択する画面が表示される。 ログインしてクエリパラメータとして渡ってきたcodeとstateもこの関数で受け取り、検証してTokenを発行し一旦返してCookieに焼く。

OpenID ConnectのIDトークンの内容と検証 - sambaiz-net

通常stateはサーバーに保持されるがDBを持たなくていいようにCookieに焼いている。 CSRF対策のためのものなのでCookieが攻撃者に読み書きされなければ意味を為すと思っているが、このフローで問題がある場合は教えてほしい。 State、Token共にCookieに焼いているためCORSのAccess-Control-Allow-Originを*などにしてしまうと穴が開く。

package auth

import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gofrs/uuid"

	"github.com/aws/aws-lambda-go/events"
	jwt "github.com/dgrijalva/jwt-go"
	"github.com/lestrrat/go-jwx/jwk"
)

const (
	stateCookieName = "st"
	tokenCookieName = "ck"
)

type TokenResponse struct {
	AccessToken  string  `json:"access_token"`
	RefreshToken string  `json:"refresh_token"`
	IdToken      string  `json:"id_token"`
	TokenType    string  `json:"token_type"`
	ExpiresIn    int     `json:"expires_in"`
	Error        *string `json:"error"`
}

func AuthOrLogin(
	request events.APIGatewayProxyRequest,
	userPoolClientID string,
	userPoolClientSecret string,
	userPoolURL string,
	userPoolID string,
	region string) *events.APIGatewayProxyResponse {
	myselfURL := fmt.Sprintf("https://%s/%s%s", request.Headers["Host"], request.RequestContext.Stage, request.Path)

	if code := getCodeFromParams(request.QueryStringParameters); code != nil {
		if err := checkState(request); err != nil {
			fmt.Printf(err.Error())
			return newErrorResponse()
		}
		tokenResp, err := requestTokenByCode(*code, userPoolClientID, userPoolClientSecret, myselfURL, userPoolURL)
		if err != nil {
			fmt.Println(err.Error())
			return newErrorResponse()
		}
		resp := newRedirectResponse(myselfURL)
		resp.MultiValueHeaders = map[string][]string{
			"Set-Cookie": []string{
				fmt.Sprintf("%s=%s; Secure; HttpOnly", tokenCookieName, tokenResp.AccessToken),
				// delete state
				fmt.Sprintf("%s=; Expires=Thu, 1-Jan-1990 00:00:00 GMT; Secure; HttpOnly", stateCookieName),
			},
		}
		return resp
	}

	if token := getTokenFromCookie(request.Headers); token != nil {
		claims, err := parseToken(*token, userPoolClientID, userPoolID, region)
		if err == nil {
			fmt.Println(claims)
			return nil // ok
		} else {
			fmt.Println(err.Error()) // re-login
		}
	}

	loginURL, state, err := makeLoginURL(userPoolURL, userPoolClientID, myselfURL)
	if err != nil {
		fmt.Println(err.Error())
		return newErrorResponse()

	}
	resp := newRedirectResponse(loginURL)
	resp.MultiValueHeaders = map[string][]string{
		"Set-Cookie": []string{
			// delete old token
			fmt.Sprintf("%s=; Expires=Thu, 1-Jan-1990 00:00:00 GMT; Secure; HttpOnly", tokenCookieName),
			fmt.Sprintf("%s=%s; Secure; HttpOnly", stateCookieName, state),
		},
	}
	return resp
}

func getCodeFromParams(params map[string]string) *string {
	if code, ok := params["code"]; ok {
		return &code
	}
	return nil
}

func getStateFromCookie(headers map[string]string) *string {
	if cookie, ok := headers["Cookie"]; ok {
		header := http.Header{}
		header.Add("Cookie", cookie)
		request := http.Request{Header: header}
		c, err := request.Cookie(stateCookieName)
		if err != nil {
			return nil
		}
		return &c.Value
	}
	return nil
}

func getTokenFromCookie(headers map[string]string) *string {
	if cookie, ok := headers["Cookie"]; ok {
		header := http.Header{}
		header.Add("Cookie", cookie)
		request := http.Request{Header: header}
		c, err := request.Cookie(tokenCookieName)
		if err != nil {
			return nil
		}
		return &c.Value
	}
	return nil
}

func newRedirectResponse(location string) *events.APIGatewayProxyResponse {
	return &events.APIGatewayProxyResponse{
		StatusCode: http.StatusFound,
		Headers: map[string]string{
			"Location": location,
		},
	}
}

func newErrorResponse() *events.APIGatewayProxyResponse {
	return &events.APIGatewayProxyResponse{
		StatusCode: http.StatusInternalServerError,
	}
}

func makeState() (string, error) {
	uuid, err := uuid.NewV4()
	if err != nil {
		return "", err
	}
	return uuid.String(), nil
}

func checkState(request events.APIGatewayProxyRequest) error {
	if state, ok := request.QueryStringParameters["state"]; ok {
		if hasState := getStateFromCookie(request.Headers); hasState != nil {
			if state == *hasState {
				return nil
			}
		}
	}
	return errors.New("state is invalid")
}

func requestTokenByCode(code string, userPoolClientID string, userPoolClientSecret string, myselfURL string, userPoolURL string) (*TokenResponse, error) {
	// https://docs.aws.amazon.com/ja_jp/cognito/latest/developerguide/token-endpoint.html
	form := url.Values{}
	form.Add("grant_type", "authorization_code")
	form.Add("code", code)
	form.Add("redirect_uri", myselfURL)
	body := strings.NewReader(form.Encode())
	req, err := http.NewRequest("POST", fmt.Sprintf("%s/oauth2/token", userPoolURL), body)
	if err != nil {
		return nil, err
	}
	req.Header.Add("Authorization",
		fmt.Sprintf("Basic %s",
			base64.StdEncoding.EncodeToString(
				[]byte(fmt.Sprintf("%s:%s", userPoolClientID, userPoolClientSecret)))))
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	var tokenResp TokenResponse
	bytes, err := ioutil.ReadAll(res.Body)
	if err := json.Unmarshal(bytes, &tokenResp); err != nil {
		return nil, err
	}
	if tokenResp.Error != nil {
		return nil, errors.New(*tokenResp.Error)
	}
	return &tokenResp, nil
}

func makeLoginURL(userPoolURL string, userPoolClientID string, myselfURL string) (string, string, error) {
	// https://docs.aws.amazon.com/ja_jp/cognito/latest/developerguide/authorization-endpoint.html
	u, err := url.Parse(fmt.Sprintf("%s/oauth2/authorize", userPoolURL))
	if err != nil {
		return "", "", errors.New("failed to parse")
	}
	q := u.Query()
	q.Set("response_type", "code")
	q.Set("client_id", userPoolClientID)
	q.Set("redirect_uri", myselfURL)
	q.Set("identity_provider", "Google")
	state, err := makeState()
	if err != nil {
		return "", "", errors.New("failed to make state")
	}
	q.Set("state", state)
	u.RawQuery = q.Encode()
	return u.String(), state, nil
}

func parseToken(tokenString string, userPoolClientID string, userPoolID string, region string) (jwt.MapClaims, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// RS256
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return getKey(token, userPoolID, region)
	})
	if err != nil {
		return nil, err
	}
	if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
		// https://docs.aws.amazon.com/ja_jp/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
		if v, ok := claims["exp"].(float64); ok {
			if int64(v) < time.Now().Unix() {
				return nil, errors.New("token is expired")
			}
		} else {
			return nil, errors.New("failed to get claims[exp]")
		}
		// instead of aud
		if v, ok := claims["client_id"].(string); ok {
			if v != userPoolClientID {
				return nil, errors.New("token has invalid audience")
			}
		} else {
			return nil, errors.New("failed to get claims[client_id]")
		}
		if v, ok := claims["iss"].(string); ok {
			if v != fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s", region, userPoolID) {
				return nil, errors.New("token has invalid issuer")
			}
		} else {
			return nil, errors.New("failed to get claims[iss]")
		}

		return claims, nil
	}
	return nil, errors.New("token is invalid")
}

func getKey(token *jwt.Token, userPoolID string, region string) (interface{}, error) {
	set, err := jwk.FetchHTTP(
		fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", region, userPoolID),
	)
	if err != nil {
		return nil, err
	}

	keyID, ok := token.Header["kid"].(string)
	if !ok {
		return nil, errors.New("expecting JWT header to have string kid")
	}

	if key := set.LookupKeyID(keyID); len(key) == 1 {
		return key[0].Materialize()
	}

	return nil, errors.New("unable to find key")
}