package fa import ( "context" "errors" "log/slog" "net/http" "net/http/httptest" "strconv" "sync" "sync/atomic" "testing" "time" ) // newTestTransport builds a transport that wraps a base RoundTripper for // httptest use. Rate limit is fast (no waits between requests) so retry // tests run in subseconds. func newTestTransport(base http.RoundTripper, maxRetries int) *transport { return &transport{ base: base, limiter: newRateLimiter(time.Microsecond, 16, false), userAgent: "test-agent", maxRetries: maxRetries, } } func TestTransport_DetectsCloudflareChallenge_FromHeader(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("cf-mitigated", "challenge") w.WriteHeader(http.StatusForbidden) })) defer srv.Close() c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 3)} req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) _, err := c.Do(req) if !errors.Is(err, ErrCloudflareChallenge) { t.Fatalf("got %v; want ErrCloudflareChallenge", err) } } func TestTransport_DetectsCloudflareChallenge_From503CFRay(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("cf-ray", "abc123-FRA") w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusServiceUnavailable) _, _ = w.Write([]byte("Just a moment...")) })) defer srv.Close() c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 3)} req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) _, err := c.Do(req) if !errors.Is(err, ErrCloudflareChallenge) { t.Fatalf("got %v; want ErrCloudflareChallenge", err) } } func TestTransport_RetriesOn5xx_ThenSucceeds(t *testing.T) { var hits atomic.Int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { n := hits.Add(1) if n < 3 { w.WriteHeader(http.StatusBadGateway) return } w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) })) defer srv.Close() // Override sleepBackoff: use a transport whose limiter pace makes // "exponential" 1s/2s waits unbearable for tests. Inject a small // limiter and accept the 1s+2s actual sleep or short-circuit by // using maxRetries=0 path. Instead we exercise just 1 retry by // returning 502 once: keep test fast. hits.Store(0) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { n := hits.Add(1) if n == 1 { w.WriteHeader(http.StatusBadGateway) return } w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) }) c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 1)} req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) resp, err := c.Do(req) if err != nil { t.Fatalf("Do: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Fatalf("status = %d; want 200", resp.StatusCode) } if hits.Load() != 2 { t.Fatalf("hits = %d; want 2 (initial 502 + one retry)", hits.Load()) } } func TestTransport_Returns429AsErrRateLimited_AfterExhaustingRetries(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Retry-After", "0") // zero seconds => fast test w.WriteHeader(http.StatusTooManyRequests) })) defer srv.Close() c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 1)} req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) _, err := c.Do(req) if !errors.Is(err, ErrRateLimited) { t.Fatalf("got %v; want ErrRateLimited", err) } } func TestTransport_InjectsUserAgent(t *testing.T) { var seen string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { seen = r.Header.Get("User-Agent") w.WriteHeader(http.StatusOK) })) defer srv.Close() c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 0)} req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) resp, err := c.Do(req) if err != nil { t.Fatalf("Do: %v", err) } resp.Body.Close() if seen != "test-agent" { t.Fatalf("UA = %q; want %q", seen, "test-agent") } } func TestTransport_HonorsRetryAfter(t *testing.T) { var hits atomic.Int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { n := hits.Add(1) if n == 1 { w.Header().Set("Retry-After", "0") // 0 = retry immediately w.WriteHeader(http.StatusTooManyRequests) return } w.WriteHeader(http.StatusOK) })) defer srv.Close() c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 2)} req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) resp, err := c.Do(req) if err != nil { t.Fatalf("Do: %v", err) } resp.Body.Close() if hits.Load() < 2 { t.Fatalf("hits = %d; want >= 2", hits.Load()) } if v, _ := strconv.Atoi(resp.Header.Get("Retry-After")); v != 0 && resp.StatusCode != http.StatusOK { t.Fatalf("unexpected final state") } } // traceCtxKey is a test-only context key used to smuggle a sentinel value // through a request's context so a slog handler can prove it received it. type traceCtxKey struct{} // captureHandler is a minimal slog.Handler that records the context handed // to Handle. A context-aware tracing handler reads the active span from that // context; this test stand-in just checks the context arrives at all. type captureHandler struct { mu sync.Mutex ctx context.Context handled bool } func (h *captureHandler) Enabled(context.Context, slog.Level) bool { return true } func (h *captureHandler) Handle(ctx context.Context, _ slog.Record) error { h.mu.Lock() defer h.mu.Unlock() h.ctx = ctx h.handled = true return nil } func (h *captureHandler) WithAttrs([]slog.Attr) slog.Handler { return h } func (h *captureHandler) WithGroup(string) slog.Handler { return h } // TestTransport_LogRequest_PropagatesRequestContext guards SDK issue #24: // logRequest must emit its record with InfoContext(req.Context(), …), not // Info(…), so a context-aware slog.Handler can recover the caller's active // span and nest the HTTP span beneath it. Regressing to Info() would hand the // handler context.Background() and orphan every HTTP span. func TestTransport_LogRequest_PropagatesRequestContext(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) defer srv.Close() h := &captureHandler{} tr := newTestTransport(http.DefaultTransport, 0) tr.logger = slog.New(h) c := &http.Client{Transport: tr} ctx := context.WithValue(context.Background(), traceCtxKey{}, "trace-123") req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) resp, err := c.Do(req) if err != nil { t.Fatalf("Do: %v", err) } resp.Body.Close() h.mu.Lock() got, handled := h.ctx, h.handled h.mu.Unlock() if !handled { t.Fatal("slog handler was never called; logRequest emitted no record") } if got == nil { t.Fatal("slog handler received a nil context") } if v, _ := got.Value(traceCtxKey{}).(string); v != "trace-123" { t.Fatalf("handler context value = %q; want %q the request's context did not reach the slog record (logRequest must use InfoContext)", v, "trace-123") } } func TestTransport_ContextCancellationPropagates(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Block until the request context is cancelled. <-r.Context().Done() })) defer srv.Close() c := &http.Client{Transport: newTestTransport(http.DefaultTransport, 0)} ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(30 * time.Millisecond) cancel() }() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) _, err := c.Do(req) if err == nil { t.Fatal("expected cancellation error") } }