// Copyright (C) 2022 The OKit Authors // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, // DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR // OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE // OR OTHER DEALINGS IN THE SOFTWARE. package http import ( "context" "net/http" "go.pitz.tech/okit" ) const ( // OkitTraceIDHeader is a string constant that defines a common key used to propagate an okit trace. OkitTraceIDHeader = "x-okit-trace-id" // OkitSpanIDHeader is a string constant that defines a common key used to propagate the parent span. OkitSpanIDHeader = "x-okit-span-id" ) // InstrumentClient updates the provided http.Client to use an instrumented http.RoundTripper. If the provided // client.Transport is nil, then http.DefaultTransport is used. func InstrumentClient(client *http.Client) { if client.Transport == nil { client.Transport = http.DefaultTransport } client.Transport = InstrumentRoundTripper(client.Transport) } // InstrumentRoundTripper wraps the provided http.RoundTripper with an implementation that traces the remote request. func InstrumentRoundTripper(rt http.RoundTripper) http.RoundTripper { return &roundTripper{rt} } type roundTripper struct { delegate http.RoundTripper } func (r *roundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) { ctx := req.Context() statusCode := 0 tags := []okit.Tag{ okit.String("host", req.URL.Host), okit.String("method", req.Method), okit.String("path", req.URL.Path), okit.Intp("status", &statusCode), } defer okit.Trace(&ctx, tags...).Done() // propagate to server // TODO: support spec compliant propagation if v := ctx.Value(okit.SpanKey); v != nil { if span, ok := v.(okit.Span); ok { req.Header.Set(OkitTraceIDHeader, span.TraceID) req.Header.Set(OkitSpanIDHeader, span.ID) } } resp, err = r.delegate.RoundTrip(req.WithContext(ctx)) if err != nil { return nil, err } statusCode = resp.StatusCode return resp, nil } var _ http.RoundTripper = &roundTripper{} // InstrumentHandler wraps the provided http.Handler with an implementation that traces HTTP requests and attaches // relevant tags to the underlying span. Using the ctx, we can resolve the last span. func InstrumentHandler(delegate http.Handler) http.Handler { return &handler{delegate} } type handler struct { delegate http.Handler } func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // propagation from client traceID := r.Header.Get(OkitTraceIDHeader) spanID := r.Header.Get(OkitSpanIDHeader) if traceID != "" && spanID != "" { ctx = context.WithValue(ctx, okit.SpanKey, okit.Span{ TraceID: traceID, ID: spanID, }) } statusCode := 200 tags := []okit.Tag{ okit.String("host", r.URL.Host), okit.String("method", r.Method), okit.String("path", r.URL.Path), okit.Intp("status", &statusCode), } defer okit.Trace(&ctx, tags...).Done() h.delegate.ServeHTTP( &response{ delegate: w, statusCode: &statusCode, }, r.WithContext(ctx), ) } var _ http.Handler = &handler{} type response struct { delegate http.ResponseWriter statusCode *int } func (r *response) Header() http.Header { return r.delegate.Header() } func (r *response) Write(bytes []byte) (int, error) { return r.delegate.Write(bytes) } func (r *response) WriteHeader(statusCode int) { *(r.statusCode) = statusCode r.delegate.WriteHeader(statusCode) } var _ http.ResponseWriter = &response{}