mirror of
https://github.com/restic/restic.git
synced 2024-09-07 11:59:25 +02:00
208 lines
5.1 KiB
Go
208 lines
5.1 KiB
Go
|
// Copyright 2017, Google
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
// you may not use this file except in compliance with the License.
|
||
|
// You may obtain a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// Unless required by applicable law or agreed to in writing, software
|
||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
// See the License for the specific language governing permissions and
|
||
|
// limitations under the License.
|
||
|
|
||
|
// Package transport provides http.RoundTrippers that may be useful to clients
|
||
|
// of Blazer.
|
||
|
package transport
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"math/rand"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// WithFailures returns an http.RoundTripper that wraps an existing
// RoundTripper, causing failures according to the options given. If rt is
// nil, the http.DefaultTransport is wrapped.
//
// The options are applied in the order given, each one adjusting the shared
// failure configuration.
func WithFailures(rt http.RoundTripper, opts ...FailureOption) http.RoundTripper {
	base := rt
	if base == nil {
		base = http.DefaultTransport
	}
	o := &options{rt: base}
	for i := range opts {
		opts[i](o)
	}
	return o
}
|
||
|
|
||
|
type options struct {
|
||
|
pathSubstrings []string
|
||
|
failureRate float64
|
||
|
status int
|
||
|
stall time.Duration
|
||
|
rt http.RoundTripper
|
||
|
msg string
|
||
|
trg *triggerReaderGroup
|
||
|
}
|
||
|
|
||
|
func (o *options) doRequest(req *http.Request) (*http.Response, error) {
|
||
|
if o.trg != nil && req.Body != nil {
|
||
|
req.Body = o.trg.new(req.Body)
|
||
|
}
|
||
|
resp, err := o.rt.RoundTrip(req)
|
||
|
if resp != nil && o.trg != nil {
|
||
|
resp.Body = o.trg.new(resp.Body)
|
||
|
}
|
||
|
return resp, err
|
||
|
}
|
||
|
|
||
|
func (o *options) RoundTrip(req *http.Request) (*http.Response, error) {
|
||
|
// TODO: fix triggering conditions
|
||
|
if rand.Float64() > o.failureRate {
|
||
|
return o.doRequest(req)
|
||
|
}
|
||
|
|
||
|
var match bool
|
||
|
if len(o.pathSubstrings) == 0 {
|
||
|
match = true
|
||
|
}
|
||
|
for _, ss := range o.pathSubstrings {
|
||
|
if strings.Contains(req.URL.Path, ss) {
|
||
|
match = true
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
if !match {
|
||
|
return o.doRequest(req)
|
||
|
}
|
||
|
|
||
|
if o.status > 0 {
|
||
|
resp := &http.Response{
|
||
|
Status: fmt.Sprintf("%d %s", o.status, http.StatusText(o.status)),
|
||
|
StatusCode: o.status,
|
||
|
Body: ioutil.NopCloser(strings.NewReader(o.msg)),
|
||
|
Request: req,
|
||
|
}
|
||
|
return resp, nil
|
||
|
}
|
||
|
|
||
|
if o.stall > 0 {
|
||
|
ctx := req.Context()
|
||
|
select {
|
||
|
case <-time.After(o.stall):
|
||
|
case <-ctx.Done():
|
||
|
}
|
||
|
}
|
||
|
return o.doRequest(req)
|
||
|
}
|
||
|
|
||
|
// A FailureOption specifies the kind of failure that the RoundTripper should
|
||
|
// display.
|
||
|
type FailureOption func(*options)
|
||
|
|
||
|
// MatchPathSubstring restricts the RoundTripper to URLs whose paths contain
|
||
|
// the given string. The default behavior is to match all paths.
|
||
|
func MatchPathSubstring(s string) FailureOption {
|
||
|
return func(o *options) {
|
||
|
o.pathSubstrings = append(o.pathSubstrings, s)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// FailureRate causes the RoundTripper to fail a certain percentage of the
|
||
|
// time. rate should be a number between 0 and 1, where 0 will never fail and
|
||
|
// 1 will always fail. The default is never to fail.
|
||
|
func FailureRate(rate float64) FailureOption {
|
||
|
return func(o *options) {
|
||
|
o.failureRate = rate
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Response simulates a given status code. The returned http.Response will
|
||
|
// have its Status, StatusCode, and Body (with any predefined message) set.
|
||
|
func Response(status int) FailureOption {
|
||
|
return func(o *options) {
|
||
|
o.status = status
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Stall simulates a network connection failure by stalling for the given
|
||
|
// duration.
|
||
|
func Stall(dur time.Duration) FailureOption {
|
||
|
return func(o *options) {
|
||
|
o.stall = dur
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// If a specific Response is requested, the body will have the given message
|
||
|
// set.
|
||
|
func Body(msg string) FailureOption {
|
||
|
return func(o *options) {
|
||
|
o.msg = msg
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Trigger will raise the RoundTripper's failure rate to 100% when the given
|
||
|
// context is closed.
|
||
|
func Trigger(ctx context.Context) FailureOption {
|
||
|
return func(o *options) {
|
||
|
go func() {
|
||
|
<-ctx.Done()
|
||
|
o.failureRate = 1
|
||
|
}()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// AfterNBytes will call effect once (roughly) n bytes have gone over the wire.
|
||
|
// Both sent and received bytes are counted against the total. Only bytes in
|
||
|
// the body of an HTTP request are currently counted; this may change in the
|
||
|
// future. effect will only be called once, and it will block (allowing
|
||
|
// callers to simulate connection hangs).
|
||
|
func AfterNBytes(n int, effect func()) FailureOption {
|
||
|
return func(o *options) {
|
||
|
o.trg = &triggerReaderGroup{
|
||
|
bytes: int64(n),
|
||
|
trigger: effect,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type triggerReaderGroup struct {
|
||
|
bytes int64
|
||
|
trigger func()
|
||
|
triggered int64
|
||
|
}
|
||
|
|
||
|
func (rg *triggerReaderGroup) new(rc io.ReadCloser) io.ReadCloser {
|
||
|
return &triggerReader{
|
||
|
ReadCloser: rc,
|
||
|
bytes: &rg.bytes,
|
||
|
trigger: rg.trigger,
|
||
|
triggered: &rg.triggered,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type triggerReader struct {
|
||
|
io.ReadCloser
|
||
|
bytes *int64
|
||
|
trigger func()
|
||
|
triggered *int64
|
||
|
}
|
||
|
|
||
|
func (r *triggerReader) Read(p []byte) (int, error) {
|
||
|
n, err := r.ReadCloser.Read(p)
|
||
|
if atomic.AddInt64(r.bytes, -int64(n)) < 0 && atomic.CompareAndSwapInt64(r.triggered, 0, 1) {
|
||
|
// Can't use sync.Once because it blocks for *all* callers until Do returns.
|
||
|
r.trigger()
|
||
|
}
|
||
|
return n, err
|
||
|
}
|