Skip to content

Commit ba2af1b

Browse files
committed
feat: add enable/disable server options in wush serve command
Added command-line flags to enable or disable specific server options in the `wush serve` command. Updated `serve.go` to conditionally start the SSH and file transfer servers based on these flags and refactored the initialization and shutdown logic accordingly. Enhanced `cliui` with new styles for enabled and disabled states for clearer UI feedback.
1 parent d717f84 commit ba2af1b

File tree

2 files changed

+102
-53
lines changed

2 files changed

+102
-53
lines changed

cliui/cliui.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ type Styles struct {
2828
FocusedPrompt,
2929
Fuchsia,
3030
Warn,
31-
Wrap pretty.Style
31+
Wrap,
32+
Disabled,
33+
Enabled pretty.Style
3234
}
3335

3436
var (
@@ -153,6 +155,12 @@ func init() {
153155
Wrap: pretty.Style{
154156
pretty.LineWrap(80),
155157
},
158+
Disabled: pretty.Style{
159+
pretty.FgColor(Red),
160+
},
161+
Enabled: pretty.Style{
162+
pretty.FgColor(Green),
163+
},
156164
}
157165

158166
DefaultStyles.FocusedPrompt = append(

cmd/wush/serve.go

Lines changed: 93 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/prometheus/client_golang/prometheus"
1313
"github.com/schollz/progressbar/v3"
1414
"github.com/spf13/afero"
15+
xslices "golang.org/x/exp/slices"
1516
"golang.org/x/xerrors"
1617
"tailscale.com/ipn/store"
1718
"tailscale.com/net/netns"
@@ -20,6 +21,7 @@ import (
2021
cslog "cdr.dev/slog"
2122
csloghuman "cdr.dev/slog/sloggers/sloghuman"
2223
"github.com/coder/coder/v2/agent/agentssh"
24+
"github.com/coder/pretty"
2325
"github.com/coder/serpent"
2426
"github.com/coder/wush/cliui"
2527
"github.com/coder/wush/overlay"
@@ -30,6 +32,8 @@ func serveCmd() *serpent.Command {
3032
var (
3133
overlayType string
3234
verbose bool
35+
enabled = []string{}
36+
disabled = []string{}
3337
)
3438
return &serpent.Command{
3539
Use: "serve",
@@ -89,72 +93,64 @@ func serveCmd() *serpent.Command {
8993

9094
fmt.Println(cliui.Timestamp(time.Now()), "WireGuard is ready")
9195

92-
sshSrv, err := agentssh.NewServer(ctx,
93-
cslog.Make(csloghuman.Sink(logSink)),
94-
prometheus.NewRegistry(),
95-
fs,
96-
nil,
97-
)
98-
if err != nil {
99-
return err
100-
}
101-
102-
sshListener, err := ts.Listen("tcp", ":3")
103-
if err != nil {
104-
return err
105-
}
96+
closers := []io.Closer{}
10697

107-
go func() {
108-
fmt.Println(cliui.Timestamp(time.Now()), "SSH server listening")
109-
err := sshSrv.Serve(sshListener)
98+
if xslices.Contains(enabled, "ssh") && !xslices.Contains(disabled, "ssh") {
99+
sshSrv, err := agentssh.NewServer(ctx,
100+
cslog.Make(csloghuman.Sink(logSink)),
101+
prometheus.NewRegistry(),
102+
fs,
103+
nil,
104+
)
110105
if err != nil {
111-
logger.Info("ssh server exited", "err", err)
112-
}
113-
}()
114-
115-
cpListener, err := ts.Listen("tcp", ":4444")
116-
if err != nil {
117-
return err
118-
}
119-
120-
go http.Serve(cpListener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
121-
if r.Method != "POST" {
122-
w.WriteHeader(http.StatusOK)
123-
w.Write([]byte("OK"))
124-
return
106+
return err
125107
}
108+
closers = append(closers, sshSrv)
126109

127-
fiName := strings.TrimPrefix(r.URL.Path, "/")
128-
defer r.Body.Close()
129-
130-
fi, err := os.OpenFile(fiName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
110+
sshListener, err := ts.Listen("tcp", ":3")
131111
if err != nil {
132-
http.Error(w, err.Error(), http.StatusInternalServerError)
133-
return
112+
return err
134113
}
114+
closers = append(closers, sshListener)
115+
116+
fmt.Println(cliui.Timestamp(time.Now()), "SSH server "+pretty.Sprint(cliui.DefaultStyles.Enabled, "enabled"))
117+
go func() {
118+
err := sshSrv.Serve(sshListener)
119+
if err != nil {
120+
fmt.Println(cliui.Timestamp(time.Now()), "SSH server exited: "+err.Error())
121+
}
122+
}()
123+
} else {
124+
fmt.Println(cliui.Timestamp(time.Now()), "SSH server "+pretty.Sprint(cliui.DefaultStyles.Disabled, "disabled"))
125+
}
135126

136-
bar := progressbar.DefaultBytes(
137-
r.ContentLength,
138-
fmt.Sprintf("Downloading %q", fiName),
139-
)
140-
_, err = io.Copy(io.MultiWriter(fi, bar), r.Body)
127+
if xslices.Contains(enabled, "cp") && !xslices.Contains(disabled, "cp") {
128+
cpListener, err := ts.Listen("tcp", ":4444")
141129
if err != nil {
142-
http.Error(w, err.Error(), http.StatusInternalServerError)
143-
return
130+
return err
144131
}
145-
fi.Close()
146-
bar.Close()
147-
148-
w.WriteHeader(http.StatusOK)
149-
w.Write([]byte(fmt.Sprintf("File %q written", fiName)))
150-
fmt.Printf("Received file %s from %s\n", fiName, r.RemoteAddr)
151-
}))
132+
closers = append([]io.Closer{cpListener}, closers...)
133+
134+
fmt.Println(cliui.Timestamp(time.Now()), "File transfer server "+pretty.Sprint(cliui.DefaultStyles.Enabled, "enabled"))
135+
go func() {
136+
err := http.Serve(cpListener, http.HandlerFunc(cpHandler))
137+
if err != nil {
138+
fmt.Println(cliui.Timestamp(time.Now()), "File transfer server exited: "+err.Error())
139+
}
140+
}()
141+
} else {
142+
fmt.Println(cliui.Timestamp(time.Now()), "File transfer server "+pretty.Sprint(cliui.DefaultStyles.Disabled, "disabled"))
143+
}
152144

153145
ctx, ctxCancel := inv.SignalNotifyContext(ctx, os.Interrupt)
154146
defer ctxCancel()
155147

148+
closers = append(closers, ts)
156149
<-ctx.Done()
157-
return sshSrv.Close()
150+
for _, closer := range closers {
151+
closer.Close()
152+
}
153+
return nil
158154
},
159155
Options: []serpent.Option{
160156
{
@@ -169,6 +165,18 @@ func serveCmd() *serpent.Command {
169165
Default: "false",
170166
Value: serpent.BoolOf(&verbose),
171167
},
168+
{
169+
Flag: "enable",
170+
Description: "Server options to enable.",
171+
Default: "ssh,cp",
172+
Value: serpent.EnumArrayOf(&enabled, "ssh", "cp"),
173+
},
174+
{
175+
Flag: "disable",
176+
Description: "Server options to disable.",
177+
Default: "",
178+
Value: serpent.EnumArrayOf(&disabled, "ssh", "cp"),
179+
},
172180
},
173181
}
174182
}
@@ -198,3 +206,36 @@ func newTSNet(direction string) (*tsnet.Server, error) {
198206

199207
return srv, nil
200208
}
209+
210+
func cpHandler(w http.ResponseWriter, r *http.Request) {
211+
if r.Method != "POST" {
212+
w.WriteHeader(http.StatusOK)
213+
w.Write([]byte("OK"))
214+
return
215+
}
216+
217+
fiName := strings.TrimPrefix(r.URL.Path, "/")
218+
defer r.Body.Close()
219+
220+
fi, err := os.OpenFile(fiName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
221+
if err != nil {
222+
http.Error(w, err.Error(), http.StatusInternalServerError)
223+
return
224+
}
225+
226+
bar := progressbar.DefaultBytes(
227+
r.ContentLength,
228+
fmt.Sprintf("Downloading %q", fiName),
229+
)
230+
_, err = io.Copy(io.MultiWriter(fi, bar), r.Body)
231+
if err != nil {
232+
http.Error(w, err.Error(), http.StatusInternalServerError)
233+
return
234+
}
235+
fi.Close()
236+
bar.Close()
237+
238+
w.WriteHeader(http.StatusOK)
239+
w.Write([]byte(fmt.Sprintf("File %q written", fiName)))
240+
fmt.Printf("Received file %s from %s\n", fiName, r.RemoteAddr)
241+
}

0 commit comments

Comments
 (0)