Skip to content

Commit b2af013

Browse files
author
Mat Ryer
committed
added StartStopper
1 parent 3c5cf52 commit b2af013

File tree

2 files changed

+167
-0
lines changed

2 files changed

+167
-0
lines changed

start/start.go

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package start
2+
3+
import (
4+
"sync"
5+
"time"
6+
7+
"github.com/stretchr/pat/stop"
8+
)
9+
10+
// StartStopper represents types that need starting and stopping.
11+
type StartStopper interface {
12+
stop.Stopper
13+
// Start is called to start the operations of the
14+
// type, an error is returned if starting fails, otherwise
15+
// nil.
16+
Start() error
17+
}
18+
19+
// All starts all StartStoppers and returns a map of any
20+
// errors that occurred keyed by the StartStopper object.
21+
// len(All(s...)) == 0 when everything started successfully.
22+
func All(startStoppers ...StartStopper) map[StartStopper]error {
23+
24+
var wg sync.WaitGroup
25+
errs := make(map[StartStopper]error)
26+
var errsL sync.Mutex
27+
wg.Add(len(startStoppers))
28+
29+
// start everything
30+
for _, s := range startStoppers {
31+
go func(s StartStopper) {
32+
if err := s.Start(); err != nil {
33+
errsL.Lock()
34+
errs[s] = err
35+
errsL.Unlock()
36+
}
37+
wg.Done()
38+
}(s)
39+
}
40+
41+
// wait for all things to start
42+
wg.Wait()
43+
44+
return errs
45+
}
46+
47+
// MustAll starts all StartStoppers and if any one fails to start,
48+
// stops all of the others and returns a map of errors that occurred, keyed
49+
// by the StartStopper object.
50+
// len(All(s...)) == 0 when everything started successfully.
51+
// By the time MustAll returns, everything will have been given the chance to
52+
// start, and in the event of an error, everything will be properly stopped.
53+
func MustAll(stopGrace time.Duration, startStoppers ...StartStopper) map[StartStopper]error {
54+
errs := All(startStoppers...)
55+
if len(errs) > 0 {
56+
// at least one thing failed to start - stop everything that
57+
// started
58+
var wg sync.WaitGroup
59+
for _, s := range startStoppers {
60+
if _, failed := errs[s]; failed == false {
61+
s.Stop(stopGrace)
62+
wg.Add(1)
63+
go func() {
64+
<-s.StopChan()
65+
wg.Done()
66+
}()
67+
}
68+
}
69+
wg.Wait() // wait for everything to stop
70+
}
71+
return errs
72+
}

start/start_test.go

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package start_test
2+
3+
import (
4+
"errors"
5+
"testing"
6+
"time"
7+
8+
"github.com/stretchr/pat/start"
9+
"github.com/stretchr/pat/stop"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
type TestStarter struct {
14+
running bool
15+
}
16+
17+
var _ start.StartStopper = (*TestStarter)(nil)
18+
19+
func (t *TestStarter) Start() error {
20+
time.Sleep(1 * time.Second)
21+
t.running = true
22+
return nil
23+
}
24+
func (t *TestStarter) Stop(time.Duration) {
25+
t.running = false
26+
}
27+
func (t *TestStarter) StopChan() <-chan stop.Signal {
28+
return stop.Stopped()
29+
}
30+
31+
type ErrorStarter struct {
32+
running bool
33+
}
34+
35+
var _ start.StartStopper = (*ErrorStarter)(nil)
36+
37+
func (t *ErrorStarter) Start() error {
38+
time.Sleep(1 * time.Second)
39+
return errors.New("something went wrong")
40+
}
41+
func (t *ErrorStarter) Stop(time.Duration) {
42+
t.running = false
43+
}
44+
func (t *ErrorStarter) StopChan() <-chan stop.Signal {
45+
return stop.Stopped()
46+
}
47+
48+
func TestAll(t *testing.T) {
49+
50+
s1 := &TestStarter{}
51+
s2 := &TestStarter{}
52+
s3 := &TestStarter{}
53+
54+
errs := start.All(s1, s2, s3)
55+
require.Equal(t, 0, len(errs))
56+
57+
require.True(t, s1.running)
58+
require.True(t, s2.running)
59+
require.True(t, s3.running)
60+
61+
}
62+
63+
func TestAllErr(t *testing.T) {
64+
65+
s1 := &TestStarter{}
66+
s2 := &TestStarter{}
67+
s3 := &ErrorStarter{}
68+
69+
errs := start.All(s1, s2, s3)
70+
require.Equal(t, 1, len(errs))
71+
72+
require.Equal(t, errs[s3].Error(), "something went wrong")
73+
74+
require.True(t, s1.running)
75+
require.True(t, s2.running)
76+
require.False(t, s3.running)
77+
78+
}
79+
80+
func TestMustAll(t *testing.T) {
81+
82+
s1 := &TestStarter{}
83+
s2 := &TestStarter{}
84+
s3 := &ErrorStarter{}
85+
86+
errs := start.MustAll(500*time.Millisecond, s1, s2, s3)
87+
require.Equal(t, 1, len(errs))
88+
89+
require.Equal(t, errs[s3].Error(), "something went wrong")
90+
91+
require.False(t, s1.running)
92+
require.False(t, s2.running)
93+
require.False(t, s3.running)
94+
95+
}

0 commit comments

Comments
 (0)