diff --git a/cmd/envbuilder/main.go b/cmd/envbuilder/main.go index 9159f84a..410e0897 100644 --- a/cmd/envbuilder/main.go +++ b/cmd/envbuilder/main.go @@ -37,15 +37,6 @@ func envbuilderCmd() serpent.Command { Options: o.CLI(), Handler: func(inv *serpent.Invocation) error { o.SetDefaults() - var preExecs []func() - preExec := func() { - for _, fn := range preExecs { - fn() - } - preExecs = nil - } - defer preExec() // Ensure cleanup in case of error. - o.Logger = log.New(os.Stderr, o.Verbose) if o.CoderAgentURL != "" { if o.CoderAgentToken == "" { @@ -58,10 +49,7 @@ func envbuilderCmd() serpent.Command { coderLog, closeLogs, err := log.Coder(inv.Context(), u, o.CoderAgentToken) if err == nil { o.Logger = log.Wrap(o.Logger, coderLog) - preExecs = append(preExecs, func() { - o.Logger(log.LevelInfo, "Closing logs") - closeLogs() - }) + defer closeLogs() // This adds the envbuilder subsystem. // If telemetry is enabled in a Coder deployment, // this will be reported and help us understand @@ -90,7 +78,7 @@ func envbuilderCmd() serpent.Command { return nil } - err := envbuilder.Run(inv.Context(), o, preExec) + err := envbuilder.Run(inv.Context(), o) if err != nil { o.Logger(log.LevelError, "error: %s", err) } diff --git a/envbuilder.go b/envbuilder.go index 94998165..683f6a54 100644 --- a/envbuilder.go +++ b/envbuilder.go @@ -84,9 +84,7 @@ type execArgsInfo struct { // Logger is the logf to use for all operations. // Filesystem is the filesystem to use for all operations. // Defaults to the host filesystem. -// preExec are any functions that should be called before exec'ing the init -// command. This is useful for ensuring that defers get run. -func Run(ctx context.Context, opts options.Options, preExec ...func()) error { +func Run(ctx context.Context, opts options.Options) error { var args execArgsInfo // Run in a separate function to ensure all defers run before we // setuid or exec. @@ -105,9 +103,6 @@ func Run(ctx context.Context, opts options.Options, preExec ...func()) error { } opts.Logger(log.LevelInfo, "=== Running the init command %s %+v as the %q user...", opts.InitCommand, args.InitArgs, args.UserInfo.user.Username) - for _, fn := range preExec { - fn() - } err = syscall.Exec(args.InitCommand, append([]string{args.InitCommand}, args.InitArgs...), args.Environ) if err != nil { diff --git a/go.mod b/go.mod index 9fa1d696..b3fa7843 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,6 @@ require ( github.com/gliderlabs/ssh v0.3.7 github.com/go-git/go-billy/v5 v5.5.0 github.com/go-git/go-git/v5 v5.12.0 - github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.20.1 github.com/google/uuid v1.6.0 github.com/hashicorp/go-multierror v1.1.1 @@ -150,6 +149,7 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.2 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/nftables v0.2.0 // indirect github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b // indirect github.com/gorilla/handlers v1.5.1 // indirect diff --git a/integration/integration_test.go b/integration/integration_test.go index 79b678d5..66dfe846 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -23,8 +23,6 @@ import ( "testing" "time" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/envbuilder" "github.com/coder/envbuilder/devcontainer/features" "github.com/coder/envbuilder/internal/magicdir" @@ -60,71 +58,6 @@ const ( testImageUbuntu = "localhost:5000/envbuilder-test-ubuntu:latest" ) -func TestLogs(t *testing.T) { - t.Parallel() - - token := uuid.NewString() - logsDone := make(chan struct{}) - - logHandler := func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/v2/buildinfo": - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version": "v2.8.9"}`)) - return - case "/api/v2/workspaceagents/me/logs": - w.WriteHeader(http.StatusOK) - tokHdr := r.Header.Get(codersdk.SessionTokenHeader) - assert.Equal(t, token, tokHdr) - var req agentsdk.PatchLogs - err := json.NewDecoder(r.Body).Decode(&req) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - for _, log := range req.Logs { - t.Logf("got log: %+v", log) - if strings.Contains(log.Output, "Closing logs") { - close(logsDone) - return - } - } - return - default: - t.Errorf("unexpected request to %s", r.URL.Path) - w.WriteHeader(http.StatusNotFound) - return - } - } - logSrv := httptest.NewServer(http.HandlerFunc(logHandler)) - defer logSrv.Close() - - // Ensures that a Git repository with a devcontainer.json is cloned and built. - srv := gittest.CreateGitServer(t, gittest.Options{ - Files: map[string]string{ - "devcontainer.json": `{ - "build": { - "dockerfile": "Dockerfile" - }, - }`, - "Dockerfile": fmt.Sprintf(`FROM %s`, testImageUbuntu), - }, - }) - _, err := runEnvbuilder(t, runOpts{env: []string{ - envbuilderEnv("GIT_URL", srv.URL), - "CODER_AGENT_URL=" + logSrv.URL, - "CODER_AGENT_TOKEN=" + token, - }}) - require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - select { - case <-ctx.Done(): - t.Fatal("timed out waiting for logs") - case <-logsDone: - } -} - func TestInitScriptInitCommand(t *testing.T) { t.Parallel() diff --git a/log/coder.go b/log/coder.go index b551140d..d8b4fe0d 100644 --- a/log/coder.go +++ b/log/coder.go @@ -6,7 +6,6 @@ import ( "fmt" "net/url" "os" - "sync" "time" "cdr.dev/slog" @@ -28,14 +27,13 @@ var ( minAgentAPIV2 = "v2.9" ) -// Coder establishes a connection to the Coder instance located at coderURL and -// authenticates using token. It then establishes a dRPC connection to the Agent -// API and begins sending logs. If the version of Coder does not support the -// Agent API, it will fall back to using the PatchLogs endpoint. The closer is -// used to close the logger and to wait at most logSendGracePeriod for logs to -// be sent. Cancelling the context will close the logs immediately without -// waiting for logs to be sent. -func Coder(ctx context.Context, coderURL *url.URL, token string) (logger Func, closer func(), err error) { +// Coder establishes a connection to the Coder instance located at +// coderURL and authenticates using token. It then establishes a +// dRPC connection to the Agent API and begins sending logs. +// If the version of Coder does not support the Agent API, it will +// fall back to using the PatchLogs endpoint. +// The returned function is used to block until all logs are sent. +func Coder(ctx context.Context, coderURL *url.URL, token string) (Func, func(), error) { // To troubleshoot issues, we need some way of logging. metaLogger := slog.Make(sloghuman.Sink(os.Stderr)) defer metaLogger.Sync() @@ -46,11 +44,9 @@ func Coder(ctx context.Context, coderURL *url.URL, token string) (logger Func, c } if semver.Compare(semver.MajorMinor(bi.Version), minAgentAPIV2) < 0 { metaLogger.Warn(ctx, "Detected Coder version incompatible with AgentAPI v2, falling back to deprecated API", slog.F("coder_version", bi.Version)) - logger, closer = sendLogsV1(ctx, client, metaLogger.Named("send_logs_v1")) - return logger, closer, nil + sendLogs, flushLogs := sendLogsV1(ctx, client, metaLogger.Named("send_logs_v1")) + return sendLogs, flushLogs, nil } - // Note that ctx passed to initRPC will be inherited by the - // underlying connection, nothing we can do about that here. dac, err := initRPC(ctx, client, metaLogger.Named("init_rpc")) if err != nil { // Logged externally @@ -58,14 +54,8 @@ func Coder(ctx context.Context, coderURL *url.URL, token string) (logger Func, c } ls := agentsdk.NewLogSender(metaLogger.Named("coder_log_sender")) metaLogger.Warn(ctx, "Sending logs via AgentAPI v2", slog.F("coder_version", bi.Version)) - logger, closer = sendLogsV2(ctx, dac, ls, metaLogger.Named("send_logs_v2")) - var closeOnce sync.Once - return logger, func() { - closer() - closeOnce.Do(func() { - _ = dac.DRPCConn().Close() - }) - }, nil + sendLogs, doneFunc := sendLogsV2(ctx, dac, ls, metaLogger.Named("send_logs_v2")) + return sendLogs, doneFunc, nil } type coderLogSender interface { @@ -84,7 +74,7 @@ func initClient(coderURL *url.URL, token string) *agentsdk.Client { func initRPC(ctx context.Context, client *agentsdk.Client, l slog.Logger) (proto.DRPCAgentClient20, error) { var c proto.DRPCAgentClient20 var err error - retryCtx, retryCancel := context.WithTimeout(ctx, rpcConnectTimeout) + retryCtx, retryCancel := context.WithTimeout(context.Background(), rpcConnectTimeout) defer retryCancel() attempts := 0 for r := retry.New(100*time.Millisecond, time.Second); r.Wait(retryCtx); { @@ -105,67 +95,65 @@ func initRPC(ctx context.Context, client *agentsdk.Client, l slog.Logger) (proto // sendLogsV1 uses the PatchLogs endpoint to send logs. // This is deprecated, but required for backward compatibility with older versions of Coder. -func sendLogsV1(ctx context.Context, client *agentsdk.Client, l slog.Logger) (logger Func, closer func()) { +func sendLogsV1(ctx context.Context, client *agentsdk.Client, l slog.Logger) (Func, func()) { // nolint: staticcheck // required for backwards compatibility - sendLog, flushAndClose := agentsdk.LogsSender(agentsdk.ExternalLogSourceID, client.PatchLogs, slog.Logger{}) - var mu sync.Mutex + sendLogs, flushLogs := agentsdk.LogsSender(agentsdk.ExternalLogSourceID, client.PatchLogs, slog.Logger{}) return func(lvl Level, msg string, args ...any) { log := agentsdk.Log{ CreatedAt: time.Now(), Output: fmt.Sprintf(msg, args...), Level: codersdk.LogLevel(lvl), } - mu.Lock() - defer mu.Unlock() - if err := sendLog(ctx, log); err != nil { + if err := sendLogs(ctx, log); err != nil { l.Warn(ctx, "failed to send logs to Coder", slog.Error(err)) } }, func() { - ctx, cancel := context.WithTimeout(ctx, logSendGracePeriod) - defer cancel() - if err := flushAndClose(ctx); err != nil { + if err := flushLogs(ctx); err != nil { l.Warn(ctx, "failed to flush logs", slog.Error(err)) } } } // sendLogsV2 uses the v2 agent API to send logs. Only compatibile with coder versions >= 2.9. -func sendLogsV2(ctx context.Context, dest agentsdk.LogDest, ls coderLogSender, l slog.Logger) (logger Func, closer func()) { - sendCtx, sendCancel := context.WithCancel(ctx) +func sendLogsV2(ctx context.Context, dest agentsdk.LogDest, ls coderLogSender, l slog.Logger) (Func, func()) { done := make(chan struct{}) uid := uuid.New() go func() { defer close(done) - if err := ls.SendLoop(sendCtx, dest); err != nil { + if err := ls.SendLoop(ctx, dest); err != nil { if !errors.Is(err, context.Canceled) { l.Warn(ctx, "failed to send logs to Coder", slog.Error(err)) } } + + // Wait for up to 10 seconds for logs to finish sending. + sendCtx, sendCancel := context.WithTimeout(context.Background(), logSendGracePeriod) + defer sendCancel() + // Try once more to send any pending logs + if err := ls.SendLoop(sendCtx, dest); err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + l.Warn(ctx, "failed to send remaining logs to Coder", slog.Error(err)) + } + } + ls.Flush(uid) + if err := ls.WaitUntilEmpty(sendCtx); err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + l.Warn(ctx, "log sender did not empty", slog.Error(err)) + } + } }() - var closeOnce sync.Once - return func(l Level, msg string, args ...any) { - ls.Enqueue(uid, agentsdk.Log{ - CreatedAt: time.Now(), - Output: fmt.Sprintf(msg, args...), - Level: codersdk.LogLevel(l), - }) - }, func() { - closeOnce.Do(func() { - // Trigger a flush and wait for logs to be sent. - ls.Flush(uid) - ctx, cancel := context.WithTimeout(ctx, logSendGracePeriod) - defer cancel() - err := ls.WaitUntilEmpty(ctx) - if err != nil { - l.Warn(ctx, "log sender did not empty", slog.Error(err)) - } + logFunc := func(l Level, msg string, args ...any) { + ls.Enqueue(uid, agentsdk.Log{ + CreatedAt: time.Now(), + Output: fmt.Sprintf(msg, args...), + Level: codersdk.LogLevel(l), + }) + } - // Stop the send loop. - sendCancel() - }) + doneFunc := func() { + <-done + } - // Wait for the send loop to finish. - <-done - } + return logFunc, doneFunc } diff --git a/log/coder_internal_test.go b/log/coder_internal_test.go index 8b8bb632..4895150e 100644 --- a/log/coder_internal_test.go +++ b/log/coder_internal_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "math/rand" "net/http" "net/http/httptest" "net/url" @@ -39,8 +38,10 @@ func TestCoder(t *testing.T) { defer closeOnce.Do(func() { close(gotLogs) }) tokHdr := r.Header.Get(codersdk.SessionTokenHeader) assert.Equal(t, token, tokHdr) - req, ok := decodeV1Logs(t, w, r) - if !ok { + var req agentsdk.PatchLogs + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) return } if assert.Len(t, req.Logs, 1) { @@ -53,44 +54,15 @@ func TestCoder(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - - logger, _ := newCoderLogger(ctx, t, srv.URL, token) - logger(LevelInfo, "hello %s", "world") + u, err := url.Parse(srv.URL) + require.NoError(t, err) + log, closeLog, err := Coder(ctx, u, token) + require.NoError(t, err) + defer closeLog() + log(LevelInfo, "hello %s", "world") <-gotLogs }) - t.Run("V1/Close", func(t *testing.T) { - t.Parallel() - - var got []agentsdk.Log - handler := func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/api/v2/buildinfo" { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version": "v2.8.9"}`)) - return - } - req, ok := decodeV1Logs(t, w, r) - if !ok { - return - } - got = append(got, req.Logs...) - } - srv := httptest.NewServer(http.HandlerFunc(handler)) - defer srv.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - logger, closer := newCoderLogger(ctx, t, srv.URL, uuid.NewString()) - logger(LevelInfo, "1") - logger(LevelInfo, "2") - closer() - logger(LevelInfo, "3") - require.Len(t, got, 2) - assert.Equal(t, "1", got[0].Output) - assert.Equal(t, "2", got[1].Output) - }) - t.Run("V1/ErrUnauthorized", func(t *testing.T) { t.Parallel() @@ -168,31 +140,42 @@ func TestCoder(t *testing.T) { require.Len(t, ld.logs, 10) }) - // In this test, we just fake out the DRPC server. - t.Run("V2/Close", func(t *testing.T) { + // In this test, we just stand up an endpoint that does not + // do dRPC. We'll try to connect, fail to websocket upgrade + // and eventually give up. + t.Run("V2/Err", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ld := &fakeLogDest{t: t} - ls := agentsdk.NewLogSender(slogtest.Make(t, nil)) - logger, closer := sendLogsV2(ctx, ld, ls, slogtest.Make(t, nil)) - defer closer() - - logger(LevelInfo, "1") - logger(LevelInfo, "2") - closer() - logger(LevelInfo, "3") + token := uuid.NewString() + handlerDone := make(chan struct{}) + var closeOnce sync.Once + handler := func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v2/buildinfo" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"version": "v2.9.0"}`)) + return + } + defer closeOnce.Do(func() { close(handlerDone) }) + w.WriteHeader(http.StatusOK) + } + srv := httptest.NewServer(http.HandlerFunc(handler)) + defer srv.Close() - require.Len(t, ld.logs, 2) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + u, err := url.Parse(srv.URL) + require.NoError(t, err) + _, _, err = Coder(ctx, u, token) + require.ErrorContains(t, err, "failed to WebSocket dial") + require.ErrorIs(t, err, context.DeadlineExceeded) + <-handlerDone }) // In this test, we validate that a 401 error on the initial connect // results in a retry. When envbuilder initially attempts to connect // using the Coder agent token, the workspace build may not yet have // completed. - t.Run("V2/Retry", func(t *testing.T) { + t.Run("V2Retry", func(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -238,99 +221,6 @@ func TestCoder(t *testing.T) { }) } -//nolint:paralleltest // We need to replace a global timeout. -func TestCoderRPCTimeout(t *testing.T) { - // This timeout is picked with the current subtests in mind, it - // should not be changed without good reason. - testReplaceTimeout(t, &rpcConnectTimeout, 500*time.Millisecond) - - // In this test, we just stand up an endpoint that does not - // do dRPC. We'll try to connect, fail to websocket upgrade - // and eventually give up after rpcConnectTimeout. - t.Run("V2/Err", func(t *testing.T) { - t.Parallel() - - token := uuid.NewString() - handlerDone := make(chan struct{}) - handlerWait := make(chan struct{}) - var closeOnce sync.Once - handler := func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/api/v2/buildinfo" { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version": "v2.9.0"}`)) - return - } - defer closeOnce.Do(func() { close(handlerDone) }) - <-handlerWait - w.WriteHeader(http.StatusOK) - } - srv := httptest.NewServer(http.HandlerFunc(handler)) - defer srv.Close() - - ctx, cancel := context.WithTimeout(context.Background(), rpcConnectTimeout/2) - defer cancel() - u, err := url.Parse(srv.URL) - require.NoError(t, err) - _, _, err = Coder(ctx, u, token) - require.ErrorContains(t, err, "failed to WebSocket dial") - require.ErrorIs(t, err, context.DeadlineExceeded) - close(handlerWait) - <-handlerDone - }) - - t.Run("V2/Timeout", func(t *testing.T) { - t.Parallel() - - token := uuid.NewString() - handlerDone := make(chan struct{}) - handlerWait := make(chan struct{}) - var closeOnce sync.Once - handler := func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/api/v2/buildinfo" { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version": "v2.9.0"}`)) - return - } - defer closeOnce.Do(func() { close(handlerDone) }) - <-handlerWait - w.WriteHeader(http.StatusOK) - } - srv := httptest.NewServer(http.HandlerFunc(handler)) - defer srv.Close() - - ctx, cancel := context.WithTimeout(context.Background(), rpcConnectTimeout*2) - defer cancel() - u, err := url.Parse(srv.URL) - require.NoError(t, err) - _, _, err = Coder(ctx, u, token) - require.ErrorContains(t, err, "failed to WebSocket dial") - require.ErrorIs(t, err, context.DeadlineExceeded) - close(handlerWait) - <-handlerDone - }) -} - -func decodeV1Logs(t *testing.T, w http.ResponseWriter, r *http.Request) (agentsdk.PatchLogs, bool) { - t.Helper() - var req agentsdk.PatchLogs - err := json.NewDecoder(r.Body).Decode(&req) - if !assert.NoError(t, err) { - http.Error(w, err.Error(), http.StatusBadRequest) - return req, false - } - return req, true -} - -func newCoderLogger(ctx context.Context, t *testing.T, us string, token string) (Func, func()) { - t.Helper() - u, err := url.Parse(us) - require.NoError(t, err) - logger, closer, err := Coder(ctx, u, token) - require.NoError(t, err) - t.Cleanup(closer) - return logger, closer -} - type fakeLogDest struct { t testing.TB logs []*proto.Log @@ -341,27 +231,3 @@ func (d *fakeLogDest) BatchCreateLogs(ctx context.Context, request *proto.BatchC d.logs = append(d.logs, request.Logs...) return &proto.BatchCreateLogsResponse{}, nil } - -func testReplaceTimeout(t *testing.T, v *time.Duration, d time.Duration) { - t.Helper() - if isParallel(t) { - t.Fatal("cannot replace timeout in parallel test") - } - old := *v - *v = d - t.Cleanup(func() { *v = old }) -} - -func isParallel(t *testing.T) (ret bool) { - t.Helper() - // This is a hack to determine if the test is running in parallel - // via property of t.Setenv. - defer func() { - if r := recover(); r != nil { - ret = true - } - }() - // Random variable name to avoid collisions. - t.Setenv(fmt.Sprintf("__TEST_CHECK_IS_PARALLEL_%d", rand.Int()), "1") - return false -}