diff --git a/mian.go b/main.go similarity index 92% rename from mian.go rename to main.go index c53ea2e..28763ed 100644 --- a/mian.go +++ b/main.go @@ -13,7 +13,7 @@ func main() { wp.Do(func() error { for j := 0; j < 5; j++ { //每次打印0-10的值 //fmt.Println(fmt.Sprintf("%v->\t%v", ii, j)) - time.Sleep(1 * time.Second) + time.Sleep(1 * time.Microsecond) } return nil }) @@ -22,7 +22,6 @@ func main() { } wp.Wait() - fmt.Println(wp.IsDone()) fmt.Println("down") diff --git a/workerpool/def.go b/workerpool/def.go index 5e87d8e..09630da 100644 --- a/workerpool/def.go +++ b/workerpool/def.go @@ -5,15 +5,9 @@ import ( "time" ) -// // CallHandler process .定义调用回调体(可修改) -// type CallHandler func() - // TaskHandler process .定义函数回调体 type TaskHandler func() error -// ServeHandler must process tls.Config.NextProto negotiated requests. -//type ServeHandler func(c net.Conn) error - // workerPool serves incoming connections via a pool of workers // in FILO order, i.e. the most recently stopped worker will serve the next // incoming connection. @@ -21,11 +15,11 @@ type TaskHandler func() error // Such a scheme keeps CPU caches hot (in theory). type WorkerPool struct { //sync.Mutex - maxWorkersCount int //最大的工作协程数 - closed int32 - errChan chan error //错误chan - timeout time.Duration //最大超时时间 - wg sync.WaitGroup - task chan TaskHandler - start sync.Once + //maxWorkersCount int //最大的工作协程数 + //start sync.Once + closed int32 + errChan chan error //错误chan + timeout time.Duration //最大超时时间 + wg sync.WaitGroup + task chan TaskHandler } diff --git a/workerpool/workerpool.go b/workerpool/workerpool.go index e070c96..ca1f387 100644 --- a/workerpool/workerpool.go +++ b/workerpool/workerpool.go @@ -15,11 +15,14 @@ func New(max int) *WorkerPool { max = 1 } - return &WorkerPool{ - maxWorkersCount: max, - task: make(chan TaskHandler, max), - errChan: make(chan error, 1), + p := &WorkerPool{ + task: make(chan TaskHandler, 2*max), + errChan: make(chan error, 1), } + + go p.loop(max) + + return p } //SetTimeout 设置超时时间 @@ -27,22 +30,9 @@ func (p *WorkerPool) SetTimeout(timeout time.Duration) { p.timeout = timeout } -//SingleCall 单程执行(排他) -// func (p *WorkerPool) SingleCall(fn TaskHandler) { -// p.Mutex.Lock() -// fn() -// p.Mutex.Unlock() -// } - //Do 添加到工作池,并立即返回 func (p *WorkerPool) Do(fn TaskHandler) { - p.start.Do(func() { //once - p.wg.Add(p.maxWorkersCount) - go p.loop() - }) - - if atomic.LoadInt32(&p.closed) == 1 { - // 已关闭 + if p.IsClosed() { // 已关闭 return } p.task <- fn @@ -50,27 +40,22 @@ func (p *WorkerPool) Do(fn TaskHandler) { //DoWait 添加到工作池,并等待执行完成之后再返回 func (p *WorkerPool) DoWait(task TaskHandler) { - p.start.Do(func() { //once - p.wg.Add(p.maxWorkersCount) - go p.loop() - }) - - if atomic.LoadInt32(&p.closed) == 1 { // 已关闭 + if p.IsClosed() { // 已关闭 return } doneChan := make(chan struct{}) p.task <- func() error { - err := task() - close(doneChan) - return err + defer close(doneChan) + return task() } <-doneChan } -func (p *WorkerPool) loop() { - // 启动n个worker - for i := 0; i < p.maxWorkersCount; i++ { +func (p *WorkerPool) loop(maxWorkersCount int) { + p.wg.Add(maxWorkersCount) // 最大的工作协程数 + // 启动max个worker + for i := 0; i < maxWorkersCount; i++ { go func() { defer p.wg.Done() // worker 开始干活 @@ -132,3 +117,11 @@ func (p *WorkerPool) IsDone() bool { return len(p.task) == 0 } + +//IsClosed 是否已经关闭 +func (p *WorkerPool) IsClosed() bool { + if atomic.LoadInt32(&p.closed) == 1 { // 已关闭 + return true + } + return false +} diff --git a/workerpool/workerpool_test.go b/workerpool/workerpool_test.go index b698768..6d47a67 100644 --- a/workerpool/workerpool_test.go +++ b/workerpool/workerpool_test.go @@ -52,33 +52,6 @@ func TestWorkerPoolError(t *testing.T) { fmt.Println("down") } -//测试排他执行(到单线程模式) -// func TestWorkerPoolSingleCall(t *testing.T) { -// wp := New(2) //设置最大线程数 -// for i := 0; i < 4; i++ { //开启20个请求 -// ii := i -// wp.SingleCall(func() error { -// for j := 0; j < 2; j++ { //每次打印0-10的值 -// fmt.Println(fmt.Sprintf("%v->\t%v", ii, j)) -// if ii == 1 { -// return errors.New("my test err") -// } -// time.Sleep(1 * time.Second) -// } - -// return nil -// //time.Sleep(1 * time.Second) -// //return errors.New("my test err") -// }) -// } - -// err := wp.Wait() -// if err != nil { -// fmt.Println(err) -// } -// fmt.Println("down") -// } - //放到工作池里面 且等待执行结果 func TestWorkerPoolDoWait(t *testing.T) { wp := New(5) //设置最大线程数 @@ -87,9 +60,9 @@ func TestWorkerPoolDoWait(t *testing.T) { wp.DoWait(func() error { for j := 0; j < 5; j++ { //每次打印0-10的值 fmt.Println(fmt.Sprintf("%v->\t%v", ii, j)) - if ii == 1 { - return errors.New("my test err") - } + // if ii == 1 { + // return errors.New("my test err") + // } time.Sleep(1 * time.Second) }