From 7a3aa8746f6f9f744220cf29e9f03d6d2b64cf06 Mon Sep 17 00:00:00 2001 From: Dan Gillis Date: Tue, 2 Mar 2021 05:01:32 -0500 Subject: [PATCH] Allow setting context using idKey. Assist for issue #293. (#296) --- .travis.yml | 3 +-- hlog/hlog.go | 7 ++++++- hlog/hlog_test.go | 21 +++++++++++++++++---- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/.travis.yml b/.travis.yml index 061f265..a5d00e2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,5 @@ language: go go: -- "1.7" -- "1.8" - "1.9" - "1.10" - "1.11" @@ -9,6 +7,7 @@ go: - "1.13" - "1.14" - "1.15" +- "1.16" - "master" matrix: allow_failures: diff --git a/hlog/hlog.go b/hlog/hlog.go index 8748b25..d9a6095 100644 --- a/hlog/hlog.go +++ b/hlog/hlog.go @@ -138,6 +138,11 @@ func IDFromCtx(ctx context.Context) (id xid.ID, ok bool) { return } +// CtxWithID adds the given xid.ID to the context +func CtxWithID(ctx context.Context, id xid.ID) context.Context { + return context.WithValue(ctx, idKey{}, id) +} + // RequestIDHandler returns a handler setting a unique id to the request which can // be gathered using IDFromRequest(req). This generated id is added as a field to the // logger using the passed fieldKey as field name. The id is also added as a response @@ -154,7 +159,7 @@ func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http. id, ok := IDFromRequest(r) if !ok { id = xid.New() - ctx = context.WithValue(ctx, idKey{}, id) + ctx = CtxWithID(ctx, id) r = r.WithContext(ctx) } if fieldKey != "" { diff --git a/hlog/hlog_test.go b/hlog/hlog_test.go index 2967eb2..b1c24fe 100644 --- a/hlog/hlog_test.go +++ b/hlog/hlog_test.go @@ -4,16 +4,16 @@ package hlog import ( "bytes" + "context" "fmt" "io/ioutil" "net/http" + "net/http/httptest" "net/url" + "reflect" "testing" - "reflect" - - "net/http/httptest" - + "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/internal/cbor" ) @@ -262,3 +262,16 @@ func BenchmarkDataRace(b *testing.B) { } }) } + +func TestCtxWithID(t *testing.T) { + ctx := context.Background() + + id, _ := xid.FromString(`c0umremcie6smuu506pg`) + + want := context.Background() + want = context.WithValue(want, idKey{}, id) + + if got := CtxWithID(ctx, id); !reflect.DeepEqual(got, want) { + t.Errorf("CtxWithID() = %v, want %v", got, want) + } +}