diff --git a/ctx.go b/ctx.go index 04d8a11..2b7a682 100644 --- a/ctx.go +++ b/ctx.go @@ -2,38 +2,39 @@ package zerolog import ( "context" - "io/ioutil" ) var disabledLogger *Logger func init() { - l := New(ioutil.Discard).Level(Disabled) + l := Nop() disabledLogger = &l } type ctxKey struct{} // WithContext returns a copy of ctx with l associated. If an instance of Logger -// is already in the context, the pointer to this logger is updated with l. +// is already in the context, the context is not updated. // // For instance, to add a field to an existing logger in the context, use this // notation: // // ctx := r.Context() // l := zerolog.Ctx(ctx) -// ctx = l.With().Str("foo", "bar").WithContext(ctx) -func (l Logger) WithContext(ctx context.Context) context.Context { +// l.UpdateContext(func(c Context) Context { +// return c.Str("bar", "baz") +// }) +func (l *Logger) WithContext(ctx context.Context) context.Context { if lp, ok := ctx.Value(ctxKey{}).(*Logger); ok { - // Update existing pointer. - *lp = l - return ctx - } - if l.level == Disabled { + if lp == l { + // Do not store same logger. + return ctx + } + } else if l.level == Disabled { // Do not store disabled logger. return ctx } - return context.WithValue(ctx, ctxKey{}, &l) + return context.WithValue(ctx, ctxKey{}, l) } // Ctx returns the Logger associated with the ctx. If no logger diff --git a/ctx_test.go b/ctx_test.go index 942b723..646cd0f 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -30,18 +30,34 @@ func TestCtx(t *testing.T) { } func TestCtxDisabled(t *testing.T) { - ctx := disabledLogger.WithContext(context.Background()) + dl := New(ioutil.Discard).Level(Disabled) + ctx := dl.WithContext(context.Background()) if ctx != context.Background() { t.Error("WithContext stored a disabled logger") } - ctx = New(ioutil.Discard).WithContext(ctx) - if reflect.DeepEqual(Ctx(ctx), disabledLogger) { + l := New(ioutil.Discard).With().Str("foo", "bar").Logger() + ctx = l.WithContext(ctx) + if Ctx(ctx) != &l { t.Error("WithContext did not store logger") } - ctx = disabledLogger.WithContext(ctx) - if !reflect.DeepEqual(Ctx(ctx), disabledLogger) { - t.Error("WithContext did not update logger pointer with disabled logger") + l.UpdateContext(func(c Context) Context { + return c.Str("bar", "baz") + }) + ctx = l.WithContext(ctx) + if Ctx(ctx) != &l { + t.Error("WithContext did not store updated logger") + } + + l = l.Level(DebugLevel) + ctx = l.WithContext(ctx) + if Ctx(ctx) != &l { + t.Error("WithContext did not store copied logger") + } + + ctx = dl.WithContext(ctx) + if Ctx(ctx) != &dl { + t.Error("WithContext did not overide logger with a disabled logger") } }