API GatewayでCognitoの認証をかけて必要ならログイン画面に飛ばす処理をGoで書く

awsgolangauth

ブラウザから直接API GatewayのエンドポイントにアクセスしたときにCognitoのTokenで認証し、失敗したらログイン画面を表示させる。 API GatewayでCognitoの認証をかける場合、AuthorizerでUserPoolを指定するのが最も簡単なパターンだが、 これだとHeaderにTokenを付けてアクセスする必要があり認証に失敗するとUnauthorizedが返る。

Cognito UserPoolとAPI Gatewayで認証付きAPIを立てる - sambaiz-net

なおAPI GatewayではなくALBをLambdaの前段に挟めば今回やることが簡単に実現できる。

LambdaとALBでCognito認証をかけて失敗したらログイン画面に飛ばす - sambaiz-net

準備

UserPoolとClientを作成する。 CloudFormationで作成する場合SchemaのMutableのデフォルトがfalseで、変えると作り直されてしまうことに注意。

Resources:
  Userpool:
    Type: AWS::Cognito::UserPool
    Properties:
      AdminCreateUserConfig:
        AllowAdminCreateUserOnly: false
      Schema:
        - Mutable: true
          Name: email
          Required: true
        - Mutable: true
          Name: name
          Required: true
      UsernameAttributes:
        - email
      UserPoolName: testpool
  UserpoolClient:
    Type: AWS::Cognito::UserPoolClient
    Properties:
      UserPoolId:
        Ref: Userpool
      ClientName: testclient
      GenerateSecret: true

その後、GoogleのOAuth Client IDを作成し、フェデレーションの設定を行ってGoogleアカウントでもログインできるようにした。 これはID Poolのフェデレーティッドアイデンティティとは異なる機能。 UserPoolのドメインや、外部IdPとのAttributes Mapping、Clientの設定はCloudFormationではできないので手で行う。

追記 (2020-12-06): 今はCloudFormationで行えるようになっている。

CDKでCognito UserPoolとClientを作成しトリガーやFederationを設定する - sambaiz-net

Attributes Mappingにusernameの項目があるが、変更不可な値なのでsubのままにしておく。

設定が終わったら https://<user-pool-domain>/login?client_id=<client-id>&redirect_uri=<redirect-uri>&response_type=code にアクセスし、<redirect-uri>?code=<code> までリダイレクトされることを確認する。 invalid_requestとだけ出てしまった場合はClientのAuthorization code grantにチェックが入っているか確認する。

実装

Authorization Flowを行う。

OAuth2.0のメモ - sambaiz-net

Cookieで渡ってきたTokenを検証し、失敗した場合はログイン画面にリダイレクトさせる。 identity_providerを指定するかClientに一つしか紐づいていなければそのIdPのログイン画面に飛び、そうでなければ選択する画面が表示される。 ログインしてクエリパラメータとして渡ってきたcodeとstateもこの関数で受け取り、検証してTokenを発行し一旦返してCookieに焼く。

OpenID ConnectのIDトークンの内容と検証 - sambaiz-net

通常stateはサーバーに保持されるがDBを持たなくていいようにCookieに焼いている。 CSRF対策のためのものなのでCookieが攻撃者に読み書きされなければ意味を為すと思っているが、このフローで問題がある場合は教えてほしい。 直接アクセスすることを想定してTokenをCookieに焼いているので、CORSの、特にAccess-Control-Allow-Origin: *のような設定を行うと穴が開いてしまう。

package auth

import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gofrs/uuid"

	"github.com/aws/aws-lambda-go/events"
	jwt "github.com/dgrijalva/jwt-go"
	"github.com/lestrrat/go-jwx/jwk"
)

const (
	stateCookieName = "st"
	tokenCookieName = "ck"
)

type TokenResponse struct {
	AccessToken  string  `json:"access_token"`
	RefreshToken string  `json:"refresh_token"`
	IdToken      string  `json:"id_token"`
	TokenType    string  `json:"token_type"`
	ExpiresIn    int     `json:"expires_in"`
	Error        *string `json:"error"`
}

func AuthOrLogin(
	request events.APIGatewayProxyRequest,
	userPoolClientID string,
	userPoolClientSecret string,
	userPoolURL string,
	userPoolID string,
	region string) *events.APIGatewayProxyResponse {
	myselfURL := fmt.Sprintf("https://%s/%s%s", request.Headers["Host"], request.RequestContext.Stage, request.Path)

	if code := getCodeFromParams(request.QueryStringParameters); code != nil {
		if err := checkState(request); err != nil {
			fmt.Printf(err.Error())
			return newErrorResponse()
		}
		tokenResp, err := requestTokenByCode(*code, userPoolClientID, userPoolClientSecret, myselfURL, userPoolURL)
		if err != nil {
			fmt.Println(err.Error())
			return newErrorResponse()
		}
		resp := newRedirectResponse(myselfURL)
		resp.MultiValueHeaders = map[string][]string{
			"Set-Cookie": []string{
				fmt.Sprintf("%s=%s; Secure; HttpOnly", tokenCookieName, tokenResp.AccessToken),
				// delete state
				fmt.Sprintf("%s=; Expires=Thu, 1-Jan-1990 00:00:00 GMT; Secure; HttpOnly", stateCookieName),
			},
		}
		return resp
	}

	if token := getTokenFromCookie(request.Headers); token != nil {
		claims, err := parseToken(*token, userPoolClientID, userPoolID, region)
		if err == nil {
			fmt.Println(claims)
			return nil // ok
		} else {
			fmt.Println(err.Error()) // re-login
		}
	}

	loginURL, state, err := makeLoginURL(userPoolURL, userPoolClientID, myselfURL)
	if err != nil {
		fmt.Println(err.Error())
		return newErrorResponse()

	}
	resp := newRedirectResponse(loginURL)
	resp.MultiValueHeaders = map[string][]string{
		"Set-Cookie": []string{
			// delete old token
			fmt.Sprintf("%s=; Expires=Thu, 1-Jan-1990 00:00:00 GMT; Secure; HttpOnly", tokenCookieName),
			fmt.Sprintf("%s=%s; Secure; HttpOnly", stateCookieName, state),
		},
	}
	return resp
}

func getCodeFromParams(params map[string]string) *string {
	if code, ok := params["code"]; ok {
		return &code
	}
	return nil
}

func getStateFromCookie(headers map[string]string) *string {
	if cookie, ok := headers["Cookie"]; ok {
		header := http.Header{}
		header.Add("Cookie", cookie)
		request := http.Request{Header: header}
		c, err := request.Cookie(stateCookieName)
		if err != nil {
			return nil
		}
		return &c.Value
	}
	return nil
}

func getTokenFromCookie(headers map[string]string) *string {
	if cookie, ok := headers["Cookie"]; ok {
		header := http.Header{}
		header.Add("Cookie", cookie)
		request := http.Request{Header: header}
		c, err := request.Cookie(tokenCookieName)
		if err != nil {
			return nil
		}
		return &c.Value
	}
	return nil
}

func newRedirectResponse(location string) *events.APIGatewayProxyResponse {
	return &events.APIGatewayProxyResponse{
		StatusCode: http.StatusFound,
		Headers: map[string]string{
			"Location": location,
		},
	}
}

func newErrorResponse() *events.APIGatewayProxyResponse {
	return &events.APIGatewayProxyResponse{
		StatusCode: http.StatusInternalServerError,
	}
}

func makeState() (string, error) {
	uuid, err := uuid.NewV4()
	if err != nil {
		return "", err
	}
	return uuid.String(), nil
}

func checkState(request events.APIGatewayProxyRequest) error {
	if state, ok := request.QueryStringParameters["state"]; ok {
		if hasState := getStateFromCookie(request.Headers); hasState != nil {
			if state == *hasState {
				return nil
			}
		}
	}
	return errors.New("state is invalid")
}

func requestTokenByCode(code string, userPoolClientID string, userPoolClientSecret string, myselfURL string, userPoolURL string) (*TokenResponse, error) {
	// https://docs.aws.amazon.com/ja_jp/cognito/latest/developerguide/token-endpoint.html
	form := url.Values{}
	form.Add("grant_type", "authorization_code")
	form.Add("code", code)
	form.Add("redirect_uri", myselfURL)
	body := strings.NewReader(form.Encode())
	req, err := http.NewRequest("POST", fmt.Sprintf("%s/oauth2/token", userPoolURL), body)
	if err != nil {
		return nil, err
	}
	req.Header.Add("Authorization",
		fmt.Sprintf("Basic %s",
			base64.StdEncoding.EncodeToString(
				[]byte(fmt.Sprintf("%s:%s", userPoolClientID, userPoolClientSecret)))))
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	var tokenResp TokenResponse
	bytes, err := ioutil.ReadAll(res.Body)
	if err := json.Unmarshal(bytes, &tokenResp); err != nil {
		return nil, err
	}
	if tokenResp.Error != nil {
		return nil, errors.New(*tokenResp.Error)
	}
	return &tokenResp, nil
}

func makeLoginURL(userPoolURL string, userPoolClientID string, myselfURL string) (string, string, error) {
	// https://docs.aws.amazon.com/ja_jp/cognito/latest/developerguide/authorization-endpoint.html
	u, err := url.Parse(fmt.Sprintf("%s/oauth2/authorize", userPoolURL))
	if err != nil {
		return "", "", errors.New("failed to parse")
	}
	q := u.Query()
	q.Set("response_type", "code")
	q.Set("client_id", userPoolClientID)
	q.Set("redirect_uri", myselfURL)
	q.Set("identity_provider", "Google")
	state, err := makeState()
	if err != nil {
		return "", "", errors.New("failed to make state")
	}
	q.Set("state", state)
	u.RawQuery = q.Encode()
	return u.String(), state, nil
}

func parseToken(tokenString string, userPoolClientID string, userPoolID string, region string) (jwt.MapClaims, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// RS256
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return getKey(token, userPoolID, region)
	})
	if err != nil {
		return nil, err
	}
	if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
		// https://docs.aws.amazon.com/ja_jp/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
		if v, ok := claims["exp"].(float64); ok {
			if int64(v) < time.Now().Unix() {
				return nil, errors.New("token is expired")
			}
		} else {
			return nil, errors.New("failed to get claims[exp]")
		}
		// instead of aud
		if v, ok := claims["client_id"].(string); ok {
			if v != userPoolClientID {
				return nil, errors.New("token has invalid audience")
			}
		} else {
			return nil, errors.New("failed to get claims[client_id]")
		}
		if v, ok := claims["iss"].(string); ok {
			if v != fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s", region, userPoolID) {
				return nil, errors.New("token has invalid issuer")
			}
		} else {
			return nil, errors.New("failed to get claims[iss]")
		}

		return claims, nil
	}
	return nil, errors.New("token is invalid")
}

func getKey(token *jwt.Token, userPoolID string, region string) (interface{}, error) {
	set, err := jwk.FetchHTTP(
		fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", region, userPoolID),
	)
	if err != nil {
		return nil, err
	}

	keyID, ok := token.Header["kid"].(string)
	if !ok {
		return nil, errors.New("expecting JWT header to have string kid")
	}

	if key := set.LookupKeyID(keyID); len(key) == 1 {
		return key[0].Materialize()
	}

	return nil, errors.New("unable to find key")
}