@@ -12,6 +12,7 @@ import (
12
12
"github.com/prometheus/client_golang/prometheus"
13
13
"github.com/schollz/progressbar/v3"
14
14
"github.com/spf13/afero"
15
+ xslices "golang.org/x/exp/slices"
15
16
"golang.org/x/xerrors"
16
17
"tailscale.com/ipn/store"
17
18
"tailscale.com/net/netns"
@@ -20,6 +21,7 @@ import (
20
21
cslog "cdr.dev/slog"
21
22
csloghuman "cdr.dev/slog/sloggers/sloghuman"
22
23
"github.com/coder/coder/v2/agent/agentssh"
24
+ "github.com/coder/pretty"
23
25
"github.com/coder/serpent"
24
26
"github.com/coder/wush/cliui"
25
27
"github.com/coder/wush/overlay"
@@ -30,6 +32,8 @@ func serveCmd() *serpent.Command {
30
32
var (
31
33
overlayType string
32
34
verbose bool
35
+ enabled = []string {}
36
+ disabled = []string {}
33
37
)
34
38
return & serpent.Command {
35
39
Use : "serve" ,
@@ -89,72 +93,64 @@ func serveCmd() *serpent.Command {
89
93
90
94
fmt .Println (cliui .Timestamp (time .Now ()), "WireGuard is ready" )
91
95
92
- sshSrv , err := agentssh .NewServer (ctx ,
93
- cslog .Make (csloghuman .Sink (logSink )),
94
- prometheus .NewRegistry (),
95
- fs ,
96
- nil ,
97
- )
98
- if err != nil {
99
- return err
100
- }
101
-
102
- sshListener , err := ts .Listen ("tcp" , ":3" )
103
- if err != nil {
104
- return err
105
- }
96
+ closers := []io.Closer {}
106
97
107
- go func () {
108
- fmt .Println (cliui .Timestamp (time .Now ()), "SSH server listening" )
109
- err := sshSrv .Serve (sshListener )
98
+ if xslices .Contains (enabled , "ssh" ) && ! xslices .Contains (disabled , "ssh" ) {
99
+ sshSrv , err := agentssh .NewServer (ctx ,
100
+ cslog .Make (csloghuman .Sink (logSink )),
101
+ prometheus .NewRegistry (),
102
+ fs ,
103
+ nil ,
104
+ )
110
105
if err != nil {
111
- logger .Info ("ssh server exited" , "err" , err )
112
- }
113
- }()
114
-
115
- cpListener , err := ts .Listen ("tcp" , ":4444" )
116
- if err != nil {
117
- return err
118
- }
119
-
120
- go http .Serve (cpListener , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
121
- if r .Method != "POST" {
122
- w .WriteHeader (http .StatusOK )
123
- w .Write ([]byte ("OK" ))
124
- return
106
+ return err
125
107
}
108
+ closers = append (closers , sshSrv )
126
109
127
- fiName := strings .TrimPrefix (r .URL .Path , "/" )
128
- defer r .Body .Close ()
129
-
130
- fi , err := os .OpenFile (fiName , os .O_CREATE | os .O_RDWR | os .O_TRUNC , 0644 )
110
+ sshListener , err := ts .Listen ("tcp" , ":3" )
131
111
if err != nil {
132
- http .Error (w , err .Error (), http .StatusInternalServerError )
133
- return
112
+ return err
134
113
}
114
+ closers = append (closers , sshListener )
115
+
116
+ fmt .Println (cliui .Timestamp (time .Now ()), "SSH server " + pretty .Sprint (cliui .DefaultStyles .Enabled , "enabled" ))
117
+ go func () {
118
+ err := sshSrv .Serve (sshListener )
119
+ if err != nil {
120
+ fmt .Println (cliui .Timestamp (time .Now ()), "SSH server exited: " + err .Error ())
121
+ }
122
+ }()
123
+ } else {
124
+ fmt .Println (cliui .Timestamp (time .Now ()), "SSH server " + pretty .Sprint (cliui .DefaultStyles .Disabled , "disabled" ))
125
+ }
135
126
136
- bar := progressbar .DefaultBytes (
137
- r .ContentLength ,
138
- fmt .Sprintf ("Downloading %q" , fiName ),
139
- )
140
- _ , err = io .Copy (io .MultiWriter (fi , bar ), r .Body )
127
+ if xslices .Contains (enabled , "cp" ) && ! xslices .Contains (disabled , "cp" ) {
128
+ cpListener , err := ts .Listen ("tcp" , ":4444" )
141
129
if err != nil {
142
- http .Error (w , err .Error (), http .StatusInternalServerError )
143
- return
130
+ return err
144
131
}
145
- fi .Close ()
146
- bar .Close ()
147
-
148
- w .WriteHeader (http .StatusOK )
149
- w .Write ([]byte (fmt .Sprintf ("File %q written" , fiName )))
150
- fmt .Printf ("Received file %s from %s\n " , fiName , r .RemoteAddr )
151
- }))
132
+ closers = append ([]io.Closer {cpListener }, closers ... )
133
+
134
+ fmt .Println (cliui .Timestamp (time .Now ()), "File transfer server " + pretty .Sprint (cliui .DefaultStyles .Enabled , "enabled" ))
135
+ go func () {
136
+ err := http .Serve (cpListener , http .HandlerFunc (cpHandler ))
137
+ if err != nil {
138
+ fmt .Println (cliui .Timestamp (time .Now ()), "File transfer server exited: " + err .Error ())
139
+ }
140
+ }()
141
+ } else {
142
+ fmt .Println (cliui .Timestamp (time .Now ()), "File transfer server " + pretty .Sprint (cliui .DefaultStyles .Disabled , "disabled" ))
143
+ }
152
144
153
145
ctx , ctxCancel := inv .SignalNotifyContext (ctx , os .Interrupt )
154
146
defer ctxCancel ()
155
147
148
+ closers = append (closers , ts )
156
149
<- ctx .Done ()
157
- return sshSrv .Close ()
150
+ for _ , closer := range closers {
151
+ closer .Close ()
152
+ }
153
+ return nil
158
154
},
159
155
Options : []serpent.Option {
160
156
{
@@ -169,6 +165,18 @@ func serveCmd() *serpent.Command {
169
165
Default : "false" ,
170
166
Value : serpent .BoolOf (& verbose ),
171
167
},
168
+ {
169
+ Flag : "enable" ,
170
+ Description : "Server options to enable." ,
171
+ Default : "ssh,cp" ,
172
+ Value : serpent .EnumArrayOf (& enabled , "ssh" , "cp" ),
173
+ },
174
+ {
175
+ Flag : "disable" ,
176
+ Description : "Server options to disable." ,
177
+ Default : "" ,
178
+ Value : serpent .EnumArrayOf (& disabled , "ssh" , "cp" ),
179
+ },
172
180
},
173
181
}
174
182
}
@@ -198,3 +206,36 @@ func newTSNet(direction string) (*tsnet.Server, error) {
198
206
199
207
return srv , nil
200
208
}
209
+
210
+ func cpHandler (w http.ResponseWriter , r * http.Request ) {
211
+ if r .Method != "POST" {
212
+ w .WriteHeader (http .StatusOK )
213
+ w .Write ([]byte ("OK" ))
214
+ return
215
+ }
216
+
217
+ fiName := strings .TrimPrefix (r .URL .Path , "/" )
218
+ defer r .Body .Close ()
219
+
220
+ fi , err := os .OpenFile (fiName , os .O_CREATE | os .O_RDWR | os .O_TRUNC , 0644 )
221
+ if err != nil {
222
+ http .Error (w , err .Error (), http .StatusInternalServerError )
223
+ return
224
+ }
225
+
226
+ bar := progressbar .DefaultBytes (
227
+ r .ContentLength ,
228
+ fmt .Sprintf ("Downloading %q" , fiName ),
229
+ )
230
+ _ , err = io .Copy (io .MultiWriter (fi , bar ), r .Body )
231
+ if err != nil {
232
+ http .Error (w , err .Error (), http .StatusInternalServerError )
233
+ return
234
+ }
235
+ fi .Close ()
236
+ bar .Close ()
237
+
238
+ w .WriteHeader (http .StatusOK )
239
+ w .Write ([]byte (fmt .Sprintf ("File %q written" , fiName )))
240
+ fmt .Printf ("Received file %s from %s\n " , fiName , r .RemoteAddr )
241
+ }
0 commit comments