写在前面
实现一个ErrGroup(golang.org/x/sync/errgroup)的功能,实现Go,Wait两个方法即可,尽量不要用系统标准库,如sync.WaitGroup
分析
源码
type Group struct {
cancel func(error)
wg sync.WaitGroup
sem chan token
errOnce sync.Once
err error
}
// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel(g.err)
}
return g.err
}
// Go calls the given function in a new goroutine.
// It blocks until the new goroutine can be added without the number of
// active goroutines in the group exceeding the configured limit.
//
// The first call to return a non-nil error cancels the group's context, if the
// group was created by calling WithContext. The error will be returned by Wait.
func (g *Group) Go(f func() error) {
if g.sem != nil {
g.sem <- token{}
}
g.wg.Add(1)
go func() {
defer g.done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel(g.err)
}
})
}
}()
}
显然,官方errGroup的的实现利用了WaitGroup,用这个就比较简单了。如果不用WaitGroup呢?
作者实现
type Group interface {
Go(func() error)
Wait() error
}
type group struct {
c []chan error // channel,存放结果
}
func NewGroup() Group {
return &group{}
}
func (g *group) Go(f func() error) {
errChan := make(chan error)
idx := len(g.c)
g.c = append(g.c, errChan)
go func(idx int) {
defer func() {
if e := recover(); e != nil {
g.writeErr(idx, errors.New(fmt.Sprintf("%v", e)))
}
}()
if err := f(); err != nil {
g.writeErr(idx, err)
} else {
g.writeErr(idx, nil)
}
}(idx)
}
func (g *group) Wait() error {
var result error
for i := range g.c {
err := <-g.c[i]
if result == nil && err != nil {
result = err
}
}
return result
}
func (g *group) writeErr(idx int, err error) {
g.c[idx] <- err
close(g.c[idx])
}
思路:
- 每次调用Go函数都生成一个chan error对象,用于记录当前func执行的结果,如果执行成果则将nil写入chan中;
- 然后在Wait函数中获取chan中的值,遇到第一个err则将结果返回。注意,这里要等每个func执行完成。
测试
- 然后在Wait函数中获取chan中的值,遇到第一个err则将结果返回。注意,这里要等每个func执行完成。
func main() { g := NewGroup() var urls = []string{ // "http://www.golang.org/", // "http://www.google.com/", "http://www.somestupidname.com/", "https://www.baidu.com", "https://www.baidu.com", } for _, url := range urls { //begin := time.Now() //resp, err := http.Get(url) //if err == nil { resp.Body.Close() //} //fmt.Printf("time cost:%f\n", time.Since(begin).Seconds()) //break // Launch a goroutine to fetch the URL. url := url // https://golang.org/doc/faq#closures_and_goroutines g.Go(func() error { // Fetch the URL. resp, err := http.Get(url) if err == nil { resp.Body.Close() } return err }) } // Wait for all HTTP fetches to complete. if err := g.Wait(); err == nil { fmt.Println("Successfully fetched all URLs.") } else { fmt.Println("failed. err:%v", err) } }
面试过程中,执行测试代码的时候,程序一直卡在Wait函数里面,看了挺长时间,没有找到原因。后面面试官给出提示,说可能是google访问超时,而http.Get函数超时时间过长。将url改成baidu.com之后就没有问题了。
引申问题:http.Get默认超时时间是多少
通过阅读src/net/http/transport.go源码,可以发现默认超时时间是30s
// DefaultTransport is the default implementation of [Transport] and is
// used by [DefaultClient]. It establishes network connections as needed
// and caches them for reuse by subsequent calls. It uses HTTP proxies
// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY
// and NO_PROXY (or the lowercase versions thereof).
var DefaultTransport RoundTripper = &Transport{
Proxy: ProxyFromEnvironment,
DialContext: defaultTransportDialContext(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}),
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
通过测试代码,发现的确是30s
begin := time.Now()
resp, err := http.Get(url)
if err == nil {
resp.Body.Close()
}
fmt.Printf("time cost:%f\n", time.Since(begin).Seconds())
// time cost:30.002446
思考
能不能将group结构图中的c []chan error改成一个chan error呢?因为这个程序的要求是返回第一个报错的error,其他error可以忽略。但是要等所有func都执行完。
好像不太行,因为要等每个func都执行完,「执行完」这个动作必须给一个信号,否则Wait函数不知道什么时候所有func执行完了。
源码中用到了WaitGroup,那能不能自己实现一个WaitGroup呢?
理论上应该是可以的,下面这段代码实现了一个WaitGroup
type WaitGroup struct {
count int
m sync.Mutex
}
func NewWaitGroup() *WaitGroup {
return &WaitGroup{}
}
func (w *WaitGroup) Add(n int) {
w.m.Lock()
defer w.m.Unlock()
w.count += n
}
func (w *WaitGroup) Done() {
w.m.Lock()
defer w.m.Unlock()
if w.count > 0 {
w.count--
}
}
func (w *WaitGroup) Wait() {
for true {
if w.count == 0 {
break
} else {
time.Sleep(time.Second)
}
}
}
这个WaitGroup用到sync.Mutex,面试过程中如果这样使用的话,需要跟面试官确认一下能不能使用sync.Mutex
源码中有SetLimit函数,设置最大并发量,如果超过这个数据则执行失败。
自定义errGroup中如果也有这个设置,就可以在自定义WaitGroup中加一个带容量的chan来实现,就不需要sync.Mutex了。具体代码如下:
var token = struct{}{}
type WaitGroup struct {
result chan struct{}
}
func NewWaitGroup(n int) *WaitGroup {
return &WaitGroup{
result: make(chan struct{}, n),
}
}
func (w *WaitGroup) SetLimit(n int) {
w.result = make(chan struct{}, n)
}
func (w *WaitGroup) Add(n int) {
for i := 0; i < n; i++ {
w.result <- token
}
}
func (w *WaitGroup) Done() {
if len(w.result) > 0 {
<-w.result
}
}
func (w *WaitGroup) Wait() {
for true {
if len(w.result) == 0 {
close(w.result)
break
} else {
time.Sleep(time.Second)
}
}
}
上面两种WaitGroup实现作者没有测试过,说明思想。读者有什么想法欢迎留言讨论。