Skip to content

Commit cb319e1

Browse files
authored
Merge pull request #2615 from alexandear/refactor/errgroup-withcontext
Refactor to use errgroup.WithContext
2 parents c2047d7 + 67cfe8b commit cb319e1

File tree

1 file changed

+25
-34
lines changed

1 file changed

+25
-34
lines changed

pkg/portfwd/client.go

+25-34
Original file line numberDiff line numberDiff line change
@@ -10,39 +10,38 @@ import (
1010
"github.com/lima-vm/lima/pkg/guestagent/api"
1111
guestagentclient "github.com/lima-vm/lima/pkg/guestagent/api/client"
1212
"github.com/sirupsen/logrus"
13+
"golang.org/x/sync/errgroup"
1314
)
1415

1516
func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.Conn, guestAddr string) {
1617
defer conn.Close()
1718

1819
id := fmt.Sprintf("tcp-%s-%s", conn.LocalAddr().String(), conn.RemoteAddr().String())
19-
errCh := make(chan error, 2)
2020

2121
stream, err := client.Tunnel(ctx)
2222
if err != nil {
2323
logrus.Errorf("could not open tcp tunnel for id: %s error:%v", id, err)
2424
}
2525

26+
g, _ := errgroup.WithContext(ctx)
27+
2628
rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr}
27-
go func() {
29+
g.Go(func() error {
2830
_, err := io.Copy(rw, conn)
2931
if errors.Is(err, io.EOF) {
30-
errCh <- nil
31-
return
32+
return nil
3233
}
33-
errCh <- err
34-
}()
35-
go func() {
34+
return err
35+
})
36+
g.Go(func() error {
3637
_, err := io.Copy(conn, rw)
3738
if errors.Is(err, io.EOF) {
38-
errCh <- nil
39-
return
39+
return nil
4040
}
41-
errCh <- err
42-
}()
41+
return err
42+
})
4343

44-
err = <-errCh
45-
if err != nil {
44+
if err := g.Wait(); err != nil {
4645
logrus.Debugf("error in tcp tunnel for id: %s error:%v", id, err)
4746
}
4847
}
@@ -57,19 +56,17 @@ func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgen
5756
logrus.Errorf("could not open udp tunnel for id: %s error:%v", id, err)
5857
}
5958

60-
errCh := make(chan error, 2)
59+
g, _ := errgroup.WithContext(ctx)
6160

62-
go func() {
61+
g.Go(func() error {
6362
buf := make([]byte, 65507)
6463
for {
6564
n, addr, err := conn.ReadFrom(buf)
6665
if errors.Is(err, io.EOF) {
67-
errCh <- nil
68-
return
66+
return nil
6967
}
7068
if err != nil {
71-
errCh <- err
72-
return
69+
return err
7370
}
7471
msg := &api.TunnelMessage{
7572
Id: id + "-" + addr.String(),
@@ -79,38 +76,32 @@ func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgen
7976
UdpTargetAddr: addr.String(),
8077
}
8178
if err := stream.Send(msg); err != nil {
82-
errCh <- err
83-
return
79+
return err
8480
}
8581
}
86-
}()
82+
})
8783

88-
go func() {
84+
g.Go(func() error {
8985
for {
9086
in, err := stream.Recv()
9187
if errors.Is(err, io.EOF) {
92-
errCh <- nil
93-
return
88+
return nil
9489
}
9590
if err != nil {
96-
errCh <- err
97-
return
91+
return err
9892
}
9993
addr, err := net.ResolveUDPAddr("udp", in.UdpTargetAddr)
10094
if err != nil {
101-
errCh <- err
102-
return
95+
return err
10396
}
10497
_, err = conn.WriteTo(in.Data, addr)
10598
if err != nil {
106-
errCh <- err
107-
return
99+
return err
108100
}
109101
}
110-
}()
102+
})
111103

112-
err = <-errCh
113-
if err != nil {
104+
if err := g.Wait(); err != nil {
114105
logrus.Debugf("error in udp tunnel for id: %s error:%v", id, err)
115106
}
116107
}

0 commit comments

Comments
 (0)