diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go index 97a1aa4..7846184 100644 --- a/singleflight/singleflight.go +++ b/singleflight/singleflight.go @@ -6,7 +6,10 @@ // mechanism. package singleflight // import "golang.org/x/sync/singleflight" -import "sync" +import ( + "sync" + "sync/atomic" +) // call is an in-flight or completed singleflight.Do call type call struct { @@ -24,7 +27,7 @@ type call struct { // These fields are read and written with the singleflight // mutex held before the WaitGroup is done, and are read but // not written after the WaitGroup is done. - dups int + dups int64 chans []chan<- Result } @@ -49,6 +52,39 @@ type Result struct { // original to complete and receives the same results. // The return value shared indicates whether v was given to multiple callers. func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) { + c := g.doNoChan(key, fn) + return c.val, c.err, c.dups > 0 +} + +// 'Use' calls 'new' at most once at a time, then invokes 'fn' with the resulting values. +// The 'dispose' argument invokes after the last call to fn has returned. +// +// `Use` is designed for scenario when 'new' generates a temporary resource, which has to be cleaned up after last 'fn' is done using it +// Notes: +// 'dispose' is called at most once, after last fn been completed. 'dispose' will NOT get called if/when 'new' returns an error +// 'fn' is called on each goroutine with values returned by 'new', regardless of whether or not 'new' returned an error +// results of 'new' are passed to 'fn'. +// +// Return: 'Use' propagates return value from 'fn' +func (g *Group) Use( + key string, + new func() (interface{}, error), + fn func(interface{}, error) error, + dispose func(interface{}), +) error { + c := g.doNoChan(key, new) + if c.err == nil && dispose != nil { + defer func() { + if atomic.AddInt64(&c.dups, -1) == -1 { + dispose(c.val) + } + }() + } + + return fn(c.val, c.err) +} + +func (g *Group) doNoChan(key string, fn func() (interface{}, error)) *call { g.mu.Lock() if g.m == nil { g.m = make(map[string]*call) @@ -57,7 +93,7 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, e c.dups++ g.mu.Unlock() c.wg.Wait() - return c.val, c.err, true + return c } c := new(call) c.wg.Add(1) @@ -65,7 +101,7 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, e g.mu.Unlock() g.doCall(c, key, fn) - return c.val, c.err, c.dups > 0 + return c } // DoChan is like Do but returns a channel that will receive the diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go index ad04037..ad5ad95 100644 --- a/singleflight/singleflight_test.go +++ b/singleflight/singleflight_test.go @@ -7,12 +7,130 @@ package singleflight import ( "errors" "fmt" + "io/ioutil" + "os" "sync" "sync/atomic" "testing" "time" ) +func testConcurrentHelper(t *testing.T, inGoroutine func(routineIndex, goroutineCount int)) { + var wg, wgGoroutines sync.WaitGroup + const callers = 4 + //ref := make([]RefCounter, callers) + wgGoroutines.Add(callers) + for i := 0; i < callers; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + + wgGoroutines.Done() + wgGoroutines.Wait() // ensure that all goroutines started and reached this point + + inGoroutine(i, callers) + }(i) + } + wg.Wait() + +} + +func TestUse(t *testing.T) { + var g Group + var newCount, handleCount, disposeCount int64 + + testConcurrentHelper( + t, + func(index, goroutineCount int) { + g.Use( + "key", + // 'new' is a slow function that creates a temp resource + func() (interface{}, error) { + time.Sleep(200 * time.Millisecond) // let more goroutines enter Do + atomic.AddInt64(&newCount, 1) + return "bar", nil + }, + // 'fn' to be called by each goroutine + func(s interface{}, e error) error { + // handle s + if newCount != 1 { + t.Errorf("goroutine %v: newCount(%v) expected to be set prior to this function getting called", index, newCount) + } + atomic.AddInt64(&handleCount, 1) + if disposeCount > 0 { + t.Errorf("goroutine %v: disposeCount(%v) should not be incremented until all fn are completed", index, disposeCount) + } + return e + }, + // 'dispose' - to be called once at the end + func(s interface{}) { + // cleaning up "bar" + atomic.AddInt64(&disposeCount, 1) + if handleCount != int64(goroutineCount) { + t.Errorf("dispose is expected to be called when all %v fn been completed, but %v have been completed instead", goroutineCount, handleCount) + } + }, + ) + }, + ) + + if newCount != 1 { + t.Errorf("new expected to be called exactly once, was called %v", newCount) + } + if disposeCount != 1 { + t.Errorf("dispose expected to be called exactly once, was called %v", disposeCount) + } +} + +func TestUseWithResource(t *testing.T) { + // use this "global" var for checkes after that testConcurrentHelper call + var tempFileName string + + var g Group + testConcurrentHelper( + t, + func(_, _ int) { + g.Use( + "key", + // 'new' is a slow function that creates a temp resource + func() (interface{}, error) { + time.Sleep(200 * time.Millisecond) // let more goroutines enter Do + f, e := ioutil.TempFile("", "pat") + if e != nil { + return nil, e + } + defer f.Close() + tempFileName = f.Name() + + // fill temp file with sequence of n.Write(...) calls + + return f.Name(), e + }, + // 'fn' to be called by each goroutine + func(s interface{}, e error) error { + // handle s + if e != nil { + // send alternative payload + } + if e == nil { + /*tempFileName*/ _ = s.(string) + // send Content of tempFileName to HTTPWriter + } + return e + }, + // 'dispose' - to be called once at the end + func(s interface{}) { + // cleaning up "bar" + os.Remove(s.(string)) + }, + ) + }, + ) + if _, e := os.Stat(tempFileName); !os.IsNotExist(e) { + t.Errorf("test has created a temp file '%v', but failed to cleaned it", tempFileName) + } +} + func TestDo(t *testing.T) { var g Group v, err, _ := g.Do("key", func() (interface{}, error) {