mirror of
https://github.com/restic/restic.git
synced 2024-09-09 21:00:59 +02:00
273 lines
7.7 KiB
Go
273 lines
7.7 KiB
Go
package main
|
|
|
|
// Copyright 2017 Microsoft Corporation
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/Azure/go-autorest/autorest"
|
|
"github.com/Azure/go-autorest/autorest/adal"
|
|
"github.com/Azure/go-autorest/autorest/azure"
|
|
"golang.org/x/crypto/pkcs12"
|
|
)
|
|
|
|
// Fixed settings used by this sample.
const (
	// resource is the resource identifier handed to every token-acquisition
	// helper in this file.
	resource = "https://management.core.windows.net/"

	// nativeAppClientID is the application ID that init falls back to in
	// device mode when no applicationId flag is given.
	nativeAppClientID = "a87032a7-203c-4bf7-913c-44c50d23409a"

	// resourceGroupURLTemplate is the base URL used when building the
	// resource-group listing request.
	resourceGroupURLTemplate = "https://management.azure.com"

	// apiVersion is sent as the "api-version" query parameter of that request.
	apiVersion = "2015-01-01"
)
|
|
|
|
// Command-line configuration. The values are populated by the flags that
// init registers.
var (
	// mode selects how a token is obtained: "device" or "certificate".
	mode string

	// certificatePath is the path to the pk12/pfx certificate read in
	// certificate mode.
	certificatePath string

	// Azure identifiers supplied on the command line.
	applicationID  string
	tenantID       string
	subscriptionID string

	// tokenCachePath is the location of the oauth token cache; when empty,
	// tokens are neither loaded from nor saved to disk.
	tokenCachePath string

	// forceRefresh requests an explicit token refresh after the first listing.
	forceRefresh bool

	// NOTE(review): impatient is not bound to any flag and is not read
	// anywhere in the visible part of this file — confirm whether it is
	// still needed.
	impatient bool
)
|
|
|
|
func init() {
|
|
flag.StringVar(&mode, "mode", "device", "mode of operation for SPT creation")
|
|
flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/pfx certificate")
|
|
flag.StringVar(&applicationID, "applicationId", "", "application id")
|
|
flag.StringVar(&tenantID, "tenantId", "", "tenant id")
|
|
flag.StringVar(&subscriptionID, "subscriptionId", "", "subscription id")
|
|
flag.StringVar(&tokenCachePath, "tokenCachePath", "", "location of oauth token cache")
|
|
flag.BoolVar(&forceRefresh, "forceRefresh", false, "pass true to force a token refresh")
|
|
|
|
flag.Parse()
|
|
|
|
log.Printf("mode(%s) certPath(%s) appID(%s) tenantID(%s), subID(%s)\n",
|
|
mode, certificatePath, applicationID, tenantID, subscriptionID)
|
|
|
|
if mode == "certificate" &&
|
|
(strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "") {
|
|
log.Fatalln("Bad usage. Using certificate mode. Please specify tenantID, subscriptionID")
|
|
}
|
|
|
|
if mode != "certificate" && mode != "device" {
|
|
log.Fatalln("Bad usage. Mode must be one of 'certificate' or 'device'.")
|
|
}
|
|
|
|
if mode == "device" && strings.TrimSpace(applicationID) == "" {
|
|
log.Println("Using device mode auth. Will use `azkube` clientID since none was specified on the comand line.")
|
|
applicationID = nativeAppClientID
|
|
}
|
|
|
|
if mode == "certificate" && strings.TrimSpace(certificatePath) == "" {
|
|
log.Fatalln("Bad usage. Mode 'certificate' requires the 'certificatePath' argument.")
|
|
}
|
|
|
|
if strings.TrimSpace(tenantID) == "" || strings.TrimSpace(subscriptionID) == "" || strings.TrimSpace(applicationID) == "" {
|
|
log.Fatalln("Bad usage. Must specify the 'tenantId' and 'subscriptionId'")
|
|
}
|
|
}
|
|
|
|
func getSptFromCachedToken(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
|
|
token, err := adal.LoadToken(tokenCachePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load token from cache: %v", err)
|
|
}
|
|
|
|
spt, _ := adal.NewServicePrincipalTokenFromManualToken(
|
|
oauthConfig,
|
|
clientID,
|
|
resource,
|
|
*token,
|
|
callbacks...)
|
|
|
|
return spt, nil
|
|
}
|
|
|
|
func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
|
|
privateKey, certificate, err := pkcs12.Decode(pkcs, password)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
|
|
if !isRsaKey {
|
|
return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
|
|
}
|
|
|
|
return certificate, rsaPrivateKey, nil
|
|
}
|
|
|
|
func getSptFromCertificate(oauthConfig adal.OAuthConfig, clientID, resource, certicatePath string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
|
|
certData, err := ioutil.ReadFile(certificatePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
|
|
}
|
|
|
|
certificate, rsaPrivateKey, err := decodePkcs12(certData, "")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
|
|
}
|
|
|
|
spt, _ := adal.NewServicePrincipalTokenFromCertificate(
|
|
oauthConfig,
|
|
clientID,
|
|
certificate,
|
|
rsaPrivateKey,
|
|
resource,
|
|
callbacks...)
|
|
|
|
return spt, nil
|
|
}
|
|
|
|
func getSptFromDeviceFlow(oauthConfig adal.OAuthConfig, clientID, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {
|
|
oauthClient := &autorest.Client{}
|
|
deviceCode, err := adal.InitiateDeviceAuth(oauthClient, oauthConfig, clientID, resource)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to start device auth flow: %s", err)
|
|
}
|
|
|
|
fmt.Println(*deviceCode.Message)
|
|
|
|
token, err := adal.WaitForUserCompletion(oauthClient, deviceCode)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to finish device auth flow: %s", err)
|
|
}
|
|
|
|
spt, err := adal.NewServicePrincipalTokenFromManualToken(
|
|
oauthConfig,
|
|
clientID,
|
|
resource,
|
|
*token,
|
|
callbacks...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get oauth token from device flow: %v", err)
|
|
}
|
|
|
|
return spt, nil
|
|
}
|
|
|
|
func printResourceGroups(client *autorest.Client) error {
|
|
p := map[string]interface{}{"subscription-id": subscriptionID}
|
|
q := map[string]interface{}{"api-version": apiVersion}
|
|
|
|
req, _ := autorest.Prepare(&http.Request{},
|
|
autorest.AsGet(),
|
|
autorest.WithBaseURL(resourceGroupURLTemplate),
|
|
autorest.WithPathParameters("/subscriptions/{subscription-id}/resourcegroups", p),
|
|
autorest.WithQueryParameters(q))
|
|
|
|
resp, err := autorest.SendWithSender(client, req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
value := struct {
|
|
ResourceGroups []struct {
|
|
Name string `json:"name"`
|
|
} `json:"value"`
|
|
}{}
|
|
|
|
defer resp.Body.Close()
|
|
dec := json.NewDecoder(resp.Body)
|
|
err = dec.Decode(&value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var groupNames = make([]string, len(value.ResourceGroups))
|
|
for i, name := range value.ResourceGroups {
|
|
groupNames[i] = name.Name
|
|
}
|
|
|
|
log.Println("Groups:", strings.Join(groupNames, ", "))
|
|
return err
|
|
}
|
|
|
|
func saveToken(spt adal.Token) {
|
|
if tokenCachePath != "" {
|
|
err := adal.SaveToken(tokenCachePath, 0600, spt)
|
|
if err != nil {
|
|
log.Println("error saving token", err)
|
|
} else {
|
|
log.Println("saved token to", tokenCachePath)
|
|
}
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
var spt *adal.ServicePrincipalToken
|
|
var err error
|
|
|
|
callback := func(t adal.Token) error {
|
|
log.Println("refresh callback was called")
|
|
saveToken(t)
|
|
return nil
|
|
}
|
|
|
|
oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, tenantID)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if tokenCachePath != "" {
|
|
log.Println("tokenCachePath specified; attempting to load from", tokenCachePath)
|
|
spt, err = getSptFromCachedToken(*oauthConfig, applicationID, resource, callback)
|
|
if err != nil {
|
|
spt = nil // just in case, this is the condition below
|
|
log.Println("loading from cache failed:", err)
|
|
}
|
|
}
|
|
|
|
if spt == nil {
|
|
log.Println("authenticating via 'mode'", mode)
|
|
switch mode {
|
|
case "device":
|
|
spt, err = getSptFromDeviceFlow(*oauthConfig, applicationID, resource, callback)
|
|
case "certificate":
|
|
spt, err = getSptFromCertificate(*oauthConfig, applicationID, resource, certificatePath, callback)
|
|
}
|
|
if err != nil {
|
|
log.Fatalln("failed to retrieve token:", err)
|
|
}
|
|
|
|
// should save it as soon as you get it since Refresh won't be called for some time
|
|
if tokenCachePath != "" {
|
|
saveToken(spt.Token)
|
|
}
|
|
}
|
|
|
|
client := &autorest.Client{}
|
|
client.Authorizer = autorest.NewBearerAuthorizer(spt)
|
|
|
|
printResourceGroups(client)
|
|
|
|
if forceRefresh {
|
|
err = spt.Refresh()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
printResourceGroups(client)
|
|
}
|
|
}
|