251 lines
7.7 KiB
Go
251 lines
7.7 KiB
Go
package fa
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// newTestTransport returns a transport around base that is tuned for
// httptest-driven tests: it sends the fixed User-Agent "test-agent" and
// allows up to maxRetries retries. The limiter is configured with a
// one-microsecond interval so that pacing adds no noticeable delay and
// retry tests complete in well under a second.
func newTestTransport(base http.RoundTripper, maxRetries int) *transport {
	// NOTE(review): argument meanings of newRateLimiter (interval, 16, false)
	// are defined elsewhere in the package; presumably 16 is a burst size.
	lim := newRateLimiter(time.Microsecond, 16, false)

	tr := &transport{
		base:       base,
		limiter:    lim,
		userAgent:  "test-agent",
		maxRetries: maxRetries,
	}
	return tr
}
|
|
|
|
// TestTransport_DetectsCloudflareChallenge_FromHeader checks that a 403
// response carrying "cf-mitigated: challenge" is surfaced to the caller as
// ErrCloudflareChallenge.
func TestTransport_DetectsCloudflareChallenge_FromHeader(t *testing.T) {
	challenge := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("cf-mitigated", "challenge")
		w.WriteHeader(http.StatusForbidden)
	})
	srv := httptest.NewServer(challenge)
	defer srv.Close()

	client := &http.Client{Transport: newTestTransport(http.DefaultTransport, 3)}
	// Error ignored: NewRequest only fails on a bad method or URL, and both
	// are fixed, well-formed values here.
	req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
	if _, err := client.Do(req); !errors.Is(err, ErrCloudflareChallenge) {
		t.Fatalf("got %v; want ErrCloudflareChallenge", err)
	}
}
|
|
|
|
func TestTransport_DetectsCloudflareChallenge_From503CFRay(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("cf-ray", "abc123-FRA")
|
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
_, _ = w.Write([]byte("<html>Just a moment...</html>"))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 3)}
|
|
req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
|
|
_, err := c.Do(req)
|
|
if !errors.Is(err, ErrCloudflareChallenge) {
|
|
t.Fatalf("got %v; want ErrCloudflareChallenge", err)
|
|
}
|
|
}
|
|
|
|
// TestTransport_RetriesOn5xx_ThenSucceeds checks that a transient 5xx is
// retried and that the eventual 200 is returned to the caller.
//
// The server fails exactly once (502, then 200), so the test incurs at most a
// single retry/backoff cycle instead of a full retry schedule. The handler is
// fixed before the server starts; swapping srv.Config.Handler on a running
// server would be an unsynchronized write to state its goroutines read.
func TestTransport_RetriesOn5xx_ThenSucceeds(t *testing.T) {
	var hits atomic.Int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		if hits.Add(1) == 1 {
			w.WriteHeader(http.StatusBadGateway)
			return
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}))
	defer srv.Close()

	c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 1)}
	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	resp, err := c.Do(req)
	if err != nil {
		t.Fatalf("Do: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d; want 200", resp.StatusCode)
	}
	if got := hits.Load(); got != 2 {
		t.Fatalf("hits = %d; want 2 (initial 502 + one retry)", got)
	}
}
|
|
|
|
// TestTransport_Returns429AsErrRateLimited_AfterExhaustingRetries checks that
// when every response is a 429, the transport (configured here with
// maxRetries=1) eventually gives up and reports ErrRateLimited.
func TestTransport_Returns429AsErrRateLimited_AfterExhaustingRetries(t *testing.T) {
	alwaysLimited := func(w http.ResponseWriter, _ *http.Request) {
		// A zero-second Retry-After keeps the test fast.
		w.Header().Set("Retry-After", "0")
		w.WriteHeader(http.StatusTooManyRequests)
	}
	srv := httptest.NewServer(http.HandlerFunc(alwaysLimited))
	defer srv.Close()

	client := &http.Client{Transport: newTestTransport(http.DefaultTransport, 1)}
	// Error ignored: the method and URL are fixed and well-formed.
	req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
	if _, err := client.Do(req); !errors.Is(err, ErrRateLimited) {
		t.Fatalf("got %v; want ErrRateLimited", err)
	}
}
|
|
|
|
// TestTransport_InjectsUserAgent checks that the transport stamps its
// configured User-Agent ("test-agent", set by newTestTransport) onto
// outgoing requests.
func TestTransport_InjectsUserAgent(t *testing.T) {
	// The handler runs on a server goroutine, so the observed header is handed
	// back through an atomic rather than a plain variable shared without
	// synchronization.
	var seen atomic.Value // holds a string once the handler has run
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen.Store(r.Header.Get("User-Agent"))
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 0)}
	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	resp, err := c.Do(req)
	if err != nil {
		t.Fatalf("Do: %v", err)
	}
	resp.Body.Close()

	// If the handler never ran, Load returns nil, got is "", and the check fails.
	got, _ := seen.Load().(string)
	if got != "test-agent" {
		t.Fatalf("UA = %q; want %q", got, "test-agent")
	}
}
|
|
|
|
// TestTransport_HonorsRetryAfter checks that a 429 carrying a Retry-After
// header is retried rather than surfaced, and that the follow-up 200 reaches
// the caller after exactly one retry.
//
// Retry-After is zero so the test never sleeps. It therefore proves the header
// is accepted and the request is retried; it does not measure that a non-zero
// delay is actually waited out.
func TestTransport_HonorsRetryAfter(t *testing.T) {
	const retryAfterSeconds = 0 // retry immediately: keeps the test fast

	var hits atomic.Int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		if hits.Add(1) == 1 {
			w.Header().Set("Retry-After", strconv.Itoa(retryAfterSeconds))
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 2)}
	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	resp, err := c.Do(req)
	if err != nil {
		t.Fatalf("Do: %v", err)
	}
	defer resp.Body.Close()

	// The previous final check parsed Retry-After from the 200 response, which
	// never carries one, so it could not fail. Assert the outcome directly.
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d; want 200", resp.StatusCode)
	}
	if got := hits.Load(); got != 2 {
		t.Fatalf("hits = %d; want 2 (initial 429 + one retry)", got)
	}
}
|
|
|
|
// traceCtxKey is a test-only context key used to smuggle a sentinel value
|
|
// through a request's context so a slog handler can prove it received it.
|
|
type traceCtxKey struct{}
|
|
|
|
// captureHandler is a minimal slog.Handler that records the context handed
|
|
// to Handle. A context-aware tracing handler reads the active span from that
|
|
// context; this test stand-in just checks the context arrives at all.
|
|
type captureHandler struct {
|
|
mu sync.Mutex
|
|
ctx context.Context
|
|
handled bool
|
|
}
|
|
|
|
func (h *captureHandler) Enabled(context.Context, slog.Level) bool { return true }
|
|
|
|
func (h *captureHandler) Handle(ctx context.Context, _ slog.Record) error {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.ctx = ctx
|
|
h.handled = true
|
|
return nil
|
|
}
|
|
|
|
func (h *captureHandler) WithAttrs([]slog.Attr) slog.Handler { return h }
|
|
func (h *captureHandler) WithGroup(string) slog.Handler { return h }
|
|
|
|
// TestTransport_LogRequest_PropagatesRequestContext guards SDK issue #24:
// logRequest must emit its record with InfoContext(req.Context(), …), not
// Info(…), so a context-aware slog.Handler can recover the caller's active
// span and nest the HTTP span beneath it. Regressing to Info() would hand the
// handler context.Background() and orphan every HTTP span.
func TestTransport_LogRequest_PropagatesRequestContext(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	h := &captureHandler{}
	tr := newTestTransport(http.DefaultTransport, 0)
	tr.logger = slog.New(h)
	c := &http.Client{Transport: tr}

	// Tag the request context with a sentinel the handler must observe.
	ctx := context.WithValue(context.Background(), traceCtxKey{}, "trace-123")
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
	if err != nil {
		t.Fatalf("NewRequestWithContext: %v", err)
	}
	resp, err := c.Do(req)
	if err != nil {
		t.Fatalf("Do: %v", err)
	}
	resp.Body.Close()

	h.mu.Lock()
	got, handled := h.ctx, h.handled
	h.mu.Unlock()

	if !handled {
		t.Fatal("slog handler was never called; logRequest emitted no record")
	}
	if got == nil {
		t.Fatal("slog handler received a nil context")
	}
	if v, _ := got.Value(traceCtxKey{}).(string); v != "trace-123" {
		t.Fatalf("handler context value = %q; want %q: the request's context did not reach the slog record (logRequest must use InfoContext)", v, "trace-123")
	}
}
|
|
|
|
// TestTransport_ContextCancellationPropagates checks that cancelling the
// request's context makes the transport return an error instead of hanging.
func TestTransport_ContextCancellationPropagates(t *testing.T) {
	// The handler parks until its request context is done, so Do can only
	// finish once the client side gives up.
	srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
		<-req.Context().Done()
	}))
	defer srv.Close()

	client := &http.Client{Transport: newTestTransport(http.DefaultTransport, 0)}
	ctx, cancel := context.WithCancel(context.Background())
	// Cancel shortly after the request is in flight.
	time.AfterFunc(30*time.Millisecond, cancel)

	// Error ignored: the method and URL are fixed and well-formed.
	req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
	if _, err := client.Do(req); err == nil {
		t.Fatal("expected cancellation error")
	}
}
|