面试题:实现errGroup


写在前面

实现一个ErrGroup(golang.org/x/sync/errgroup)的功能,实现Go,Wait两个方法即可,尽量不要用系统标准库,如sync.WaitGroup

分析

源码

type Group struct {
	cancel func(error)

	wg sync.WaitGroup

	sem chan token

	errOnce sync.Once
	err     error
}

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
	g.wg.Wait()
	if g.cancel != nil {
		g.cancel(g.err)
	}
	return g.err
}

// Go calls the given function in a new goroutine.
// It blocks until the new goroutine can be added without the number of
// active goroutines in the group exceeding the configured limit.
//
// The first call to return a non-nil error cancels the group's context, if the
// group was created by calling WithContext. The error will be returned by Wait.
func (g *Group) Go(f func() error) {
	if g.sem != nil {
		g.sem <- token{}
	}

	g.wg.Add(1)
	go func() {
		defer g.done()

		if err := f(); err != nil {
			g.errOnce.Do(func() {
				g.err = err
				if g.cancel != nil {
					g.cancel(g.err)
				}
			})
		}
	}()
}

显然,官方errGroup的的实现利用了WaitGroup,用这个就比较简单了。如果不用WaitGroup呢?

作者实现

type Group interface {
	Go(func() error)
	Wait() error
}

type group struct {
	c []chan error // channel,存放结果
}

func NewGroup() Group {
	return &group{}
}

func (g *group) Go(f func() error) {
	errChan := make(chan error)
	idx := len(g.c)
	g.c = append(g.c, errChan)
	go func(idx int) {
		defer func() {
			if e := recover(); e != nil {
				g.writeErr(idx, errors.New(fmt.Sprintf("%v", e)))
			}
		}()

		if err := f(); err != nil {
			g.writeErr(idx, err)
		} else {
			g.writeErr(idx, nil)
		}
	}(idx)
}

func (g *group) Wait() error {
	var result error
	for i := range g.c {
		err := <-g.c[i]
		if result == nil && err != nil {
			result = err
		}
	}
	return result
}

func (g *group) writeErr(idx int, err error) {
	g.c[idx] <- err
	close(g.c[idx])
}

思路:

  • 每次调用Go函数都生成一个chan error对象,用于记录当前func执行的结果,如果执行成果则将nil写入chan中;
  • 然后在Wait函数中获取chan中的值,遇到第一个err则将结果返回。注意,这里要等每个func执行完成。

测试

  • 然后在Wait函数中获取chan中的值,遇到第一个err则将结果返回。注意,这里要等每个func执行完成。
    func main() {
    	g := NewGroup()
    	var urls = []string{
    		// "http://www.golang.org/",
    		// "http://www.google.com/",
    		"http://www.somestupidname.com/",
    		"https://www.baidu.com",
    		"https://www.baidu.com",
    	}
    	for _, url := range urls {
    		//begin := time.Now()
    		//resp, err := http.Get(url)
    		//if err == nil {
    			resp.Body.Close()
    		//}
    		//fmt.Printf("time cost:%f\n", time.Since(begin).Seconds())
    		//break
    		// Launch a goroutine to fetch the URL.
    		url := url // https://golang.org/doc/faq#closures_and_goroutines
    		g.Go(func() error {
    			// Fetch the URL.
    			resp, err := http.Get(url)
    			if err == nil {
    				resp.Body.Close()
    			}
    			return err
    		})
    	}
    	// Wait for all HTTP fetches to complete.
    	if err := g.Wait(); err == nil {
    		fmt.Println("Successfully fetched all URLs.")
    	} else {
    		fmt.Println("failed. err:%v", err)
    	}
    }

面试过程中,执行测试代码的时候,程序一直卡在Wait函数里面,看了挺长时间,没有找到原因。后面面试官给出提示,说可能是google访问超时,而http.Get函数超时时间过长。将url改成baidu.com之后就没有问题了。

引申问题:http.Get默认超时时间是多少

通过阅读src/net/http/transport.go源码,可以发现默认超时时间是30s

// DefaultTransport is the default implementation of [Transport] and is
// used by [DefaultClient]. It establishes network connections as needed
// and caches them for reuse by subsequent calls. It uses HTTP proxies
// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY
// and NO_PROXY (or the lowercase versions thereof).
var DefaultTransport RoundTripper = &Transport{
	Proxy: ProxyFromEnvironment,
	DialContext: defaultTransportDialContext(&net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}),
	ForceAttemptHTTP2:     true,
	MaxIdleConns:          100,
	IdleConnTimeout:       90 * time.Second,
	TLSHandshakeTimeout:   10 * time.Second,
	ExpectContinueTimeout: 1 * time.Second,
}

通过测试代码,发现的确是30s

begin := time.Now()
resp, err := http.Get(url)
if err == nil {
	resp.Body.Close()
}
fmt.Printf("time cost:%f\n", time.Since(begin).Seconds())
   // time cost:30.002446

思考

能不能将group结构图中的c []chan error改成一个chan error呢?因为这个程序的要求是返回第一个报错的error,其他error可以忽略。但是要等所有func都执行完。

好像不太行,因为要等每个func都执行完,「执行完」这个动作必须给一个信号,否则Wait函数不知道什么时候所有func执行完了。

源码中用到了WaitGroup,那能不能自己实现一个WaitGroup呢?

理论上应该是可以的,下面这段代码实现了一个WaitGroup

type WaitGroup struct {
	count int
	m     sync.Mutex
}

func NewWaitGroup() *WaitGroup {
	return &WaitGroup{}
}

func (w *WaitGroup) Add(n int) {
	w.m.Lock()
	defer w.m.Unlock()
	w.count += n
}

func (w *WaitGroup) Done() {
	w.m.Lock()
	defer w.m.Unlock()
	if w.count > 0 {
		w.count--
	}
}

func (w *WaitGroup) Wait() {
	for true {
		if w.count == 0 {
			break
		} else {
			time.Sleep(time.Second)
		}
	}
}

这个WaitGroup用到sync.Mutex,面试过程中如果这样使用的话,需要跟面试官确认一下能不能使用sync.Mutex

源码中有SetLimit函数,设置最大并发量,如果超过这个数据则执行失败。

自定义errGroup中如果也有这个设置,就可以在自定义WaitGroup中加一个带容量的chan来实现,就不需要sync.Mutex了。具体代码如下:

var token = struct{}{}

type WaitGroup struct {
	result chan struct{}
}

func NewWaitGroup(n int) *WaitGroup {
	return &WaitGroup{
		result: make(chan struct{}, n),
	}
}

func (w *WaitGroup) SetLimit(n int) {
	w.result = make(chan struct{}, n)
}

func (w *WaitGroup) Add(n int) {
	for i := 0; i < n; i++ {
		w.result <- token
	}
}

func (w *WaitGroup) Done() {
	if len(w.result) > 0 {
		<-w.result
	}
}

func (w *WaitGroup) Wait() {
	for true {
		if len(w.result) == 0 {
			close(w.result)
			break
		} else {
			time.Sleep(time.Second)
		}
	}
}

上面两种WaitGroup实现作者没有测试过,说明思想。读者有什么想法欢迎留言讨论。


文章作者: Alex
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Alex !
  目录