2017-08-05 20:30:20 +02:00
|
|
|
package adal
|
|
|
|
|
2017-10-22 10:07:36 +02:00
|
|
|
// Copyright 2017 Microsoft Corporation
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
|
2017-08-05 20:30:20 +02:00
|
|
|
/*
|
|
|
|
This file is largely based on rjw57/oauth2device's code, with the following differences:
|
|
|
|
* scope -> resource, and only allow a single one
|
|
|
|
* receive "Message" in the DeviceCode struct and show it to users as the prompt
|
|
|
|
* azure-xplat-cli has the following behavior that this emulates:
|
|
|
|
- does not send client_secret during the token exchange
|
|
|
|
- sends resource again in the token exchange request
|
|
|
|
*/
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
|
|
|
"net/http"
|
|
|
|
"net/url"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// logPrefix is prepended to every error message produced by the device
// token flow so the originating component can be identified.
const logPrefix = "autorest/adal/devicetoken:"
|
|
|
|
|
|
|
|
var (
|
|
|
|
// ErrDeviceGeneric represents an unknown error from the token endpoint when using device flow
|
|
|
|
ErrDeviceGeneric = fmt.Errorf("%s Error while retrieving OAuth token: Unknown Error", logPrefix)
|
|
|
|
|
|
|
|
// ErrDeviceAccessDenied represents an access denied error from the token endpoint when using device flow
|
|
|
|
ErrDeviceAccessDenied = fmt.Errorf("%s Error while retrieving OAuth token: Access Denied", logPrefix)
|
|
|
|
|
|
|
|
// ErrDeviceAuthorizationPending represents the server waiting on the user to complete the device flow
|
|
|
|
ErrDeviceAuthorizationPending = fmt.Errorf("%s Error while retrieving OAuth token: Authorization Pending", logPrefix)
|
|
|
|
|
|
|
|
// ErrDeviceCodeExpired represents the server timing out and expiring the code during device flow
|
|
|
|
ErrDeviceCodeExpired = fmt.Errorf("%s Error while retrieving OAuth token: Code Expired", logPrefix)
|
|
|
|
|
|
|
|
// ErrDeviceSlowDown represents the service telling us we're polling too often during device flow
|
|
|
|
ErrDeviceSlowDown = fmt.Errorf("%s Error while retrieving OAuth token: Slow Down", logPrefix)
|
|
|
|
|
|
|
|
// ErrDeviceCodeEmpty represents an empty device code from the device endpoint while using device flow
|
|
|
|
ErrDeviceCodeEmpty = fmt.Errorf("%s Error while retrieving device code: Device Code Empty", logPrefix)
|
|
|
|
|
|
|
|
// ErrOAuthTokenEmpty represents an empty OAuth token from the token endpoint when using device flow
|
|
|
|
ErrOAuthTokenEmpty = fmt.Errorf("%s Error while retrieving OAuth token: Token Empty", logPrefix)
|
|
|
|
|
|
|
|
errCodeSendingFails = "Error occurred while sending request for Device Authorization Code"
|
|
|
|
errCodeHandlingFails = "Error occurred while handling response from the Device Endpoint"
|
|
|
|
errTokenSendingFails = "Error occurred while sending request with device code for a token"
|
|
|
|
errTokenHandlingFails = "Error occurred while handling response from the Token Endpoint (during device flow)"
|
|
|
|
errStatusNotOK = "Error HTTP status != 200"
|
|
|
|
)
|
|
|
|
|
|
|
|
// DeviceCode is the object returned by the device auth endpoint
|
|
|
|
// It contains information to instruct the user to complete the auth flow
|
|
|
|
type DeviceCode struct {
|
|
|
|
DeviceCode *string `json:"device_code,omitempty"`
|
|
|
|
UserCode *string `json:"user_code,omitempty"`
|
|
|
|
VerificationURL *string `json:"verification_url,omitempty"`
|
|
|
|
ExpiresIn *int64 `json:"expires_in,string,omitempty"`
|
|
|
|
Interval *int64 `json:"interval,string,omitempty"`
|
|
|
|
|
|
|
|
Message *string `json:"message"` // Azure specific
|
|
|
|
Resource string // store the following, stored when initiating, used when exchanging
|
|
|
|
OAuthConfig OAuthConfig
|
|
|
|
ClientID string
|
|
|
|
}
|
|
|
|
|
|
|
|
// TokenError models the JSON document the token exchange endpoint answers
// with when something is amiss. All members are optional.
type TokenError struct {
	// Error is the error identifier (e.g. "authorization_pending").
	Error *string `json:"error,omitempty"`
	// ErrorCodes is decoded from the "error_codes" array.
	ErrorCodes []int `json:"error_codes,omitempty"`
	// ErrorDescription is decoded from the "error_description" member.
	ErrorDescription *string `json:"error_description,omitempty"`
	// Timestamp is kept as the raw string found in the response.
	Timestamp *string `json:"timestamp,omitempty"`
	// TraceID is decoded from the "trace_id" member.
	TraceID *string `json:"trace_id,omitempty"`
}
|
|
|
|
|
|
|
|
// DeviceToken is the object return by the token exchange endpoint
|
|
|
|
// It can either look like a Token or an ErrorToken, so put both here
|
|
|
|
// and check for presence of "Error" to know if we are in error state
|
|
|
|
type deviceToken struct {
|
|
|
|
Token
|
|
|
|
TokenError
|
|
|
|
}
|
|
|
|
|
|
|
|
// InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
|
|
|
|
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
|
|
|
|
func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
|
|
|
|
v := url.Values{
|
|
|
|
"client_id": []string{clientID},
|
|
|
|
"resource": []string{resource},
|
|
|
|
}
|
|
|
|
|
|
|
|
s := v.Encode()
|
|
|
|
body := ioutil.NopCloser(strings.NewReader(s))
|
|
|
|
|
|
|
|
req, err := http.NewRequest(http.MethodPost, oauthConfig.DeviceCodeEndpoint.String(), body)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
req.ContentLength = int64(len(s))
|
|
|
|
req.Header.Set(contentType, mimeTypeFormPost)
|
|
|
|
resp, err := sender.Do(req)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
rb, err := ioutil.ReadAll(resp.Body)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, errStatusNotOK)
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(strings.Trim(string(rb), " ")) == 0 {
|
|
|
|
return nil, ErrDeviceCodeEmpty
|
|
|
|
}
|
|
|
|
|
|
|
|
var code DeviceCode
|
|
|
|
err = json.Unmarshal(rb, &code)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeHandlingFails, err.Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
code.ClientID = clientID
|
|
|
|
code.Resource = resource
|
|
|
|
code.OAuthConfig = oauthConfig
|
|
|
|
|
|
|
|
return &code, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
|
|
|
|
// to see if the device flow has: been completed, timed out, or otherwise failed
|
|
|
|
func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
|
|
|
|
v := url.Values{
|
|
|
|
"client_id": []string{code.ClientID},
|
|
|
|
"code": []string{*code.DeviceCode},
|
|
|
|
"grant_type": []string{OAuthGrantTypeDeviceCode},
|
|
|
|
"resource": []string{code.Resource},
|
|
|
|
}
|
|
|
|
|
|
|
|
s := v.Encode()
|
|
|
|
body := ioutil.NopCloser(strings.NewReader(s))
|
|
|
|
|
|
|
|
req, err := http.NewRequest(http.MethodPost, code.OAuthConfig.TokenEndpoint.String(), body)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
req.ContentLength = int64(len(s))
|
|
|
|
req.Header.Set(contentType, mimeTypeFormPost)
|
|
|
|
resp, err := sender.Do(req)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
rb, err := ioutil.ReadAll(resp.Body)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
if resp.StatusCode != http.StatusOK && len(strings.Trim(string(rb), " ")) == 0 {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, errStatusNotOK)
|
|
|
|
}
|
|
|
|
if len(strings.Trim(string(rb), " ")) == 0 {
|
|
|
|
return nil, ErrOAuthTokenEmpty
|
|
|
|
}
|
|
|
|
|
|
|
|
var token deviceToken
|
|
|
|
err = json.Unmarshal(rb, &token)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenHandlingFails, err.Error())
|
|
|
|
}
|
|
|
|
|
|
|
|
if token.Error == nil {
|
|
|
|
return &token.Token, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
switch *token.Error {
|
|
|
|
case "authorization_pending":
|
|
|
|
return nil, ErrDeviceAuthorizationPending
|
|
|
|
case "slow_down":
|
|
|
|
return nil, ErrDeviceSlowDown
|
|
|
|
case "access_denied":
|
|
|
|
return nil, ErrDeviceAccessDenied
|
|
|
|
case "code_expired":
|
|
|
|
return nil, ErrDeviceCodeExpired
|
|
|
|
default:
|
|
|
|
return nil, ErrDeviceGeneric
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
|
|
|
|
// This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
|
|
|
|
func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
|
|
|
|
intervalDuration := time.Duration(*code.Interval) * time.Second
|
|
|
|
waitDuration := intervalDuration
|
|
|
|
|
|
|
|
for {
|
|
|
|
token, err := CheckForUserCompletion(sender, code)
|
|
|
|
|
|
|
|
if err == nil {
|
|
|
|
return token, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
switch err {
|
|
|
|
case ErrDeviceSlowDown:
|
|
|
|
waitDuration += waitDuration
|
|
|
|
case ErrDeviceAuthorizationPending:
|
|
|
|
// noop
|
|
|
|
default: // everything else is "fatal" to us
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
if waitDuration > (intervalDuration * 3) {
|
|
|
|
return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
|
|
|
|
}
|
|
|
|
|
|
|
|
time.Sleep(waitDuration)
|
|
|
|
}
|
|
|
|
}
|