diff --git a/hlog/hlog.go b/hlog/hlog.go index f7c3a5a..37f4336 100644 --- a/hlog/hlog.go +++ b/hlog/hlog.go @@ -171,6 +171,22 @@ func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http. } } +// CustomHeaderHandler adds given header from request's header as a field to +// the context's logger using fieldKey as field key. +func CustomHeaderHandler(fieldKey, header string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if val := r.Header.Get(header); val != "" { + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str(fieldKey, val) + }) + } + next.ServeHTTP(w, r) + }) + } +} + // AccessHandler returns a handler that call f after each request. func AccessHandler(f func(r *http.Request, status, size int, duration time.Duration)) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { diff --git a/hlog/hlog_test.go b/hlog/hlog_test.go index 975ce38..2967eb2 100644 --- a/hlog/hlog_test.go +++ b/hlog/hlog_test.go @@ -182,6 +182,24 @@ func TestRequestIDHandler(t *testing.T) { h.ServeHTTP(httptest.NewRecorder(), r) } +func TestCustomHeaderHandler(t *testing.T) { + out := &bytes.Buffer{} + r := &http.Request{ + Header: http.Header{ + "X-Request-Id": []string{"514bbe5bb5251c92bd07a9846f4a1ab6"}, + }, + } + h := CustomHeaderHandler("reqID", "X-Request-Id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + l := FromRequest(r) + l.Log().Msg("") + })) + h = NewHandler(zerolog.New(out))(h) + h.ServeHTTP(nil, r) + if want, got := `{"reqID":"514bbe5bb5251c92bd07a9846f4a1ab6"}`+"\n", decodeIfBinary(out); want != got { + t.Errorf("Invalid log output, got: %s, want: %s", got, want) + } +} + func TestCombinedHandlers(t *testing.T) { out := &bytes.Buffer{} r := &http.Request{