1
1
package completion
2
2
3
3
import (
4
+ "bytes"
5
+ "errors"
4
6
"fmt"
5
7
"io"
8
+ "io/fs"
6
9
"os"
7
10
"os/user"
8
11
"path/filepath"
@@ -13,11 +16,16 @@ import (
13
16
"github.com/coder/serpent"
14
17
)
15
18
19
+ const (
20
+ completionStartTemplate = `# ============ BEGIN {{.Name}} COMPLETION ============`
21
+ completionEndTemplate = `# ============ END {{.Name}} COMPLETION ==============`
22
+ )
23
+
16
24
type Shell interface {
17
25
Name () string
18
26
InstallPath () (string , error )
19
- UsesOwnFile () bool
20
27
WriteCompletion (io.Writer ) error
28
+ ProgramName () string
21
29
}
22
30
23
31
const (
@@ -77,56 +85,154 @@ func DetectUserShell(programName string) (Shell, error) {
77
85
return nil , fmt .Errorf ("default shell not found" )
78
86
}
79
87
80
- func generateCompletion (
81
- scriptTemplate string ,
82
- ) func (io.Writer , string ) error {
83
- return func (w io.Writer , programName string ) error {
84
- tmpl , err := template .New ("script" ).Parse (scriptTemplate )
85
- if err != nil {
86
- return fmt .Errorf ("parse template: %w" , err )
87
- }
88
-
89
- err = tmpl .Execute (
90
- w ,
91
- map [string ]string {
92
- "Name" : programName ,
93
- },
94
- )
95
- if err != nil {
96
- return fmt .Errorf ("execute template: %w" , err )
97
- }
88
+ func configTemplateWriter (
89
+ w io.Writer ,
90
+ cfgTemplate string ,
91
+ programName string ,
92
+ ) error {
93
+ tmpl , err := template .New ("script" ).Parse (cfgTemplate )
94
+ if err != nil {
95
+ return fmt .Errorf ("parse template: %w" , err )
96
+ }
98
97
99
- return nil
98
+ err = tmpl .Execute (
99
+ w ,
100
+ map [string ]string {
101
+ "Name" : programName ,
102
+ },
103
+ )
104
+ if err != nil {
105
+ return fmt .Errorf ("execute template: %w" , err )
100
106
}
107
+
108
+ return nil
101
109
}
102
110
103
111
func InstallShellCompletion (shell Shell ) error {
104
112
path , err := shell .InstallPath ()
105
113
if err != nil {
106
114
return fmt .Errorf ("get install path: %w" , err )
107
115
}
116
+ var headerBuf bytes.Buffer
117
+ err = configTemplateWriter (& headerBuf , completionStartTemplate , shell .ProgramName ())
118
+ if err != nil {
119
+ return fmt .Errorf ("generate header: %w" , err )
120
+ }
121
+
122
+ var footerBytes bytes.Buffer
123
+ err = configTemplateWriter (& footerBytes , completionEndTemplate , shell .ProgramName ())
124
+ if err != nil {
125
+ return fmt .Errorf ("generate footer: %w" , err )
126
+ }
108
127
109
128
err = os .MkdirAll (filepath .Dir (path ), 0o755 )
110
129
if err != nil {
111
130
return fmt .Errorf ("create directories: %w" , err )
112
131
}
113
132
114
- if shell .UsesOwnFile () {
115
- err := os .WriteFile (path , nil , 0o644 )
133
+ f , err := os .ReadFile (path )
134
+ if err != nil && ! errors .Is (err , fs .ErrNotExist ) {
135
+ return fmt .Errorf ("read ssh config failed: %w" , err )
136
+ }
137
+
138
+ before , after , err := templateConfigSplit (headerBuf .Bytes (), footerBytes .Bytes (), f )
139
+ if err != nil {
140
+ return err
141
+ }
142
+
143
+ outBuf := bytes.Buffer {}
144
+ _ , _ = outBuf .Write (before )
145
+ if len (before ) > 0 {
146
+ _ , _ = outBuf .Write ([]byte ("\n " ))
147
+ }
148
+ _ , _ = outBuf .Write (headerBuf .Bytes ())
149
+ err = shell .WriteCompletion (& outBuf )
150
+ if err != nil {
151
+ return fmt .Errorf ("generate completion: %w" , err )
152
+ }
153
+ _ , _ = outBuf .Write (footerBytes .Bytes ())
154
+ _ , _ = outBuf .Write ([]byte ("\n " ))
155
+ _ , _ = outBuf .Write (after )
156
+
157
+ err = writeWithTempFileAndMove (path , & outBuf )
158
+ if err != nil {
159
+ return fmt .Errorf ("write completion: %w" , err )
160
+ }
161
+
162
+ return nil
163
+ }
164
+
165
+ func templateConfigSplit (header , footer , data []byte ) (before , after []byte , err error ) {
166
+ startCount := bytes .Count (data , header )
167
+ endCount := bytes .Count (data , footer )
168
+ if startCount > 1 || endCount > 1 {
169
+ return nil , nil , fmt .Errorf ("Malformed config file: multiple config sections" )
170
+ }
171
+
172
+ startIndex := bytes .Index (data , header )
173
+ endIndex := bytes .Index (data , footer )
174
+ if startIndex == - 1 && endIndex != - 1 {
175
+ return data , nil , fmt .Errorf ("Malformed config file: missing completion header" )
176
+ }
177
+ if startIndex != - 1 && endIndex == - 1 {
178
+ return data , nil , fmt .Errorf ("Malformed config file: missing completion footer" )
179
+ }
180
+ if startIndex != - 1 && endIndex != - 1 {
181
+ if startIndex > endIndex {
182
+ return data , nil , fmt .Errorf ("Malformed config file: completion header after footer" )
183
+ }
184
+ // Include leading and trailing newline, if present
185
+ start := startIndex
186
+ if start > 0 {
187
+ start --
188
+ }
189
+ end := endIndex + len (footer )
190
+ if end < len (data ) {
191
+ end ++
192
+ }
193
+ return data [:start ], data [end :], nil
194
+ }
195
+ return data , nil , nil
196
+ }
197
+
198
+ // writeWithTempFileAndMove writes to a temporary file in the same
199
+ // directory as path and renames the temp file to the file provided in
200
+ // path. This ensure we avoid trashing the file we are writing due to
201
+ // unforeseen circumstance like filesystem full, command killed, etc.
202
+ func writeWithTempFileAndMove (path string , r io.Reader ) (err error ) {
203
+ dir := filepath .Dir (path )
204
+ name := filepath .Base (path )
205
+
206
+ if err = os .MkdirAll (dir , 0o700 ); err != nil {
207
+ return fmt .Errorf ("create directory: %w" , err )
208
+ }
209
+
210
+ // Create a tempfile in the same directory for ensuring write
211
+ // operation does not fail.
212
+ f , err := os .CreateTemp (dir , fmt .Sprintf (".%s." , name ))
213
+ if err != nil {
214
+ return fmt .Errorf ("create temp file failed: %w" , err )
215
+ }
216
+ defer func () {
116
217
if err != nil {
117
- return fmt . Errorf ( "create file: %w" , err )
218
+ _ = os . Remove ( f . Name ()) // Cleanup in case a step failed.
118
219
}
220
+ }()
221
+
222
+ _ , err = io .Copy (f , r )
223
+ if err != nil {
224
+ _ = f .Close ()
225
+ return fmt .Errorf ("write temp file failed: %w" , err )
119
226
}
120
227
121
- f , err := os . OpenFile ( path , os . O_CREATE | os . O_APPEND | os . O_WRONLY , 0o644 )
228
+ err = f . Close ( )
122
229
if err != nil {
123
- return fmt .Errorf ("open file for appending : %w" , err )
230
+ return fmt .Errorf ("close temp file failed : %w" , err )
124
231
}
125
- defer f .Close ()
126
232
127
- err = shell . WriteCompletion ( f )
233
+ err = os . Rename ( f . Name (), path )
128
234
if err != nil {
129
- return fmt .Errorf ("write completion script : %w" , err )
235
+ return fmt .Errorf ("rename temp file failed : %w" , err )
130
236
}
131
237
132
238
return nil
0 commit comments