Skip to content

Commit 80e0fd8

Browse files
committed
internal/mcp: support prompts
Add support for 'prompts/list' and 'prompts/get', and registering prompts on the server using AddPrompts. Add a 'MakePrompt' helper to construct prompts from a prompt handler using reflection. Change-Id: I479cb9c9b99313cca99640c1f5ea1939363759e6 Reviewed-on: https://go-review.googlesource.com/c/tools/+/669015 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Jonathan Amsterdam <[email protected]>
1 parent ab01700 commit 80e0fd8

File tree

17 files changed

+531
-98
lines changed

17 files changed

+531
-98
lines changed

internal/mcp/client.go

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,33 @@ func (sc *ServerConnection) Ping(ctx context.Context) error {
150150
return call(ctx, sc.conn, "ping", nil, nil)
151151
}
152152

153+
// ListPrompts lists prompts that are currently available on the server.
154+
func (sc *ServerConnection) ListPrompts(ctx context.Context) ([]protocol.Prompt, error) {
155+
var (
156+
params = &protocol.ListPromptsParams{}
157+
result protocol.ListPromptsResult
158+
)
159+
if err := call(ctx, sc.conn, "prompts/list", params, &result); err != nil {
160+
return nil, err
161+
}
162+
return result.Prompts, nil
163+
}
164+
165+
// GetPrompt gets a prompt from the server.
166+
func (sc *ServerConnection) GetPrompt(ctx context.Context, name string, args map[string]string) (*protocol.GetPromptResult, error) {
167+
var (
168+
params = &protocol.GetPromptParams{
169+
Name: name,
170+
Arguments: args,
171+
}
172+
result = &protocol.GetPromptResult{}
173+
)
174+
if err := call(ctx, sc.conn, "prompts/get", params, result); err != nil {
175+
return nil, err
176+
}
177+
return result, nil
178+
}
179+
153180
// ListTools lists tools that are currently available on the server.
154181
func (sc *ServerConnection) ListTools(ctx context.Context) ([]protocol.Tool, error) {
155182
var (
@@ -164,23 +191,27 @@ func (sc *ServerConnection) ListTools(ctx context.Context) ([]protocol.Tool, err
164191

165192
// CallTool calls the tool with the given name and arguments.
166193
//
167-
// TODO: make the following true:
194+
// TODO(jba): make the following true:
168195
// If the provided arguments do not conform to the schema for the given tool,
169196
// the call fails.
170-
func (sc *ServerConnection) CallTool(ctx context.Context, name string, args any) (_ []Content, err error) {
197+
func (sc *ServerConnection) CallTool(ctx context.Context, name string, args map[string]any) (_ []Content, err error) {
171198
defer func() {
172199
if err != nil {
173200
err = fmt.Errorf("calling tool %q: %w", name, err)
174201
}
175202
}()
176-
argJSON, err := json.Marshal(args)
177-
if err != nil {
178-
return nil, fmt.Errorf("marshaling args: %v", err)
203+
argsJSON := make(map[string]json.RawMessage)
204+
for name, arg := range args {
205+
argJSON, err := json.Marshal(arg)
206+
if err != nil {
207+
return nil, fmt.Errorf("marshaling argument %s: %v", name, err)
208+
}
209+
argsJSON[name] = argJSON
179210
}
180211
var (
181212
params = &protocol.CallToolParams{
182213
Name: name,
183-
Arguments: argJSON,
214+
Arguments: argsJSON,
184215
}
185216
result protocol.CallToolResult
186217
)

internal/mcp/cmd_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func TestCmdTransport(t *testing.T) {
5353
if err != nil {
5454
log.Fatal(err)
5555
}
56-
got, err := serverConn.CallTool(ctx, "greet", SayHiParams{Name: "user"})
56+
got, err := serverConn.CallTool(ctx, "greet", map[string]any{"name": "user"})
5757
if err != nil {
5858
log.Fatal(err)
5959
}

internal/mcp/examples/hello/main.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,54 @@ package main
66

77
import (
88
"context"
9+
"encoding/json"
910
"flag"
1011
"fmt"
1112
"net/http"
1213
"os"
1314

1415
"golang.org/x/tools/internal/mcp"
16+
"golang.org/x/tools/internal/mcp/internal/protocol"
1517
)
1618

1719
var httpAddr = flag.String("http", "", "if set, use SSE HTTP at this address, instead of stdin/stdout")
1820

19-
type SayHiParams struct {
21+
type HiParams struct {
2022
Name string `json:"name"`
2123
}
2224

23-
func SayHi(ctx context.Context, cc *mcp.ClientConnection, params *SayHiParams) ([]mcp.Content, error) {
25+
func SayHi(ctx context.Context, cc *mcp.ClientConnection, params *HiParams) ([]mcp.Content, error) {
2426
return []mcp.Content{
2527
mcp.TextContent{Text: "Hi " + params.Name},
2628
}, nil
2729
}
2830

31+
func PromptHi(ctx context.Context, cc *mcp.ClientConnection, params *HiParams) (*protocol.GetPromptResult, error) {
32+
// (see related TODOs about cleaning up content construction)
33+
content, err := json.Marshal(protocol.TextContent{
34+
Type: "text",
35+
Text: "Say hi to " + params.Name,
36+
})
37+
if err != nil {
38+
return nil, err
39+
}
40+
return &protocol.GetPromptResult{
41+
Description: "Code review prompt",
42+
Messages: []protocol.PromptMessage{
43+
// TODO: move 'Content' to the protocol package.
44+
{Role: "user", Content: json.RawMessage(content)},
45+
},
46+
}, nil
47+
}
48+
2949
func main() {
3050
flag.Parse()
3151

3252
server := mcp.NewServer("greeter", "v0.0.1", nil)
3353
server.AddTools(mcp.MakeTool("greet", "say hi", SayHi, mcp.Input(
3454
mcp.Property("name", mcp.Description("the name to say hi to")),
3555
)))
56+
server.AddPrompts(mcp.MakePrompt("greet", "", PromptHi))
3657

3758
if *httpAddr != "" {
3859
handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server {

internal/mcp/internal/jsonschema/infer.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ func parseField(f reflect.StructField) (name string, required, include bool) {
142142
}
143143
name = props[0]
144144
}
145+
// TODO: support 'omitzero' as well.
145146
required = !slices.Contains(props[1:], "omitempty")
146147
}
147148
return name, required, true

internal/mcp/internal/protocol/generate.go

Lines changed: 24 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@ package main
1313

1414
import (
1515
"bytes"
16-
"cmp"
1716
"encoding/json"
1817
"flag"
1918
"fmt"
2019
"go/format"
2120
"io"
22-
"iter"
2321
"log"
2422
"net/http"
2523
"os"
@@ -28,6 +26,7 @@ import (
2826
"strings"
2927

3028
"golang.org/x/tools/internal/mcp/internal/jsonschema"
29+
"golang.org/x/tools/internal/mcp/internal/util"
3130
)
3231

3332
var schemaFile = flag.String("schema_file", "", "if set, use this file as the persistent schema file")
@@ -54,31 +53,36 @@ var declarations = config{
5453
"CallToolRequest": {
5554
Fields: config{"Params": {Name: "CallToolParams"}},
5655
},
57-
"CallToolResult": {
58-
Name: "CallToolResult",
59-
},
56+
"CallToolResult": {Name: "CallToolResult"},
6057
"CancelledNotification": {
6158
Fields: config{"Params": {Name: "CancelledParams"}},
6259
},
6360
"ClientCapabilities": {Name: "ClientCapabilities"},
64-
"Implementation": {Name: "Implementation"},
61+
"GetPromptRequest": {
62+
Fields: config{"Params": {Name: "GetPromptParams"}},
63+
},
64+
"GetPromptResult": {Name: "GetPromptResult"},
65+
"Implementation": {Name: "Implementation"},
6566
"InitializeRequest": {
6667
Fields: config{"Params": {Name: "InitializeParams"}},
6768
},
68-
"InitializeResult": {
69-
Name: "InitializeResult",
70-
},
69+
"InitializeResult": {Name: "InitializeResult"},
7170
"InitializedNotification": {
7271
Fields: config{"Params": {Name: "InitializedParams"}},
7372
},
73+
"ListPromptsRequest": {
74+
Fields: config{"Params": {Name: "ListPromptsParams"}},
75+
},
76+
"ListPromptsResult": {Name: "ListPromptsResult"},
7477
"ListToolsRequest": {
7578
Fields: config{"Params": {Name: "ListToolsParams"}},
7679
},
77-
"ListToolsResult": {
78-
Name: "ListToolsResult",
79-
},
80-
"RequestId": {Substitute: "any"}, // null|number|string
81-
"Role": {Name: "Role"},
80+
"ListToolsResult": {Name: "ListToolsResult"},
81+
"Prompt": {Name: "Prompt"},
82+
"PromptMessage": {Name: "PromptMessage"},
83+
"PromptArgument": {Name: "PromptArgument"},
84+
"RequestId": {Substitute: "any"}, // null|number|string
85+
"Role": {Name: "Role"},
8286
"ServerCapabilities": {
8387
Name: "ServerCapabilities",
8488
Fields: config{
@@ -92,9 +96,7 @@ var declarations = config{
9296
Name: "Tool",
9397
Fields: config{"InputSchema": {Substitute: "*jsonschema.Schema"}},
9498
},
95-
"ToolAnnotations": {
96-
Name: "ToolAnnotations",
97-
},
99+
"ToolAnnotations": {Name: "ToolAnnotations"},
98100
}
99101

100102
func main() {
@@ -114,7 +116,7 @@ func main() {
114116
// writing types, we collect definitions and concatenate them later. This
115117
// also allows us to sort.
116118
named := make(map[string]*bytes.Buffer)
117-
for name, def := range sorted(schema.Definitions) {
119+
for name, def := range util.Sorted(schema.Definitions) {
118120
config := declarations[name]
119121
if config == nil {
120122
continue
@@ -142,7 +144,7 @@ import (
142144
`)
143145

144146
// Write out types.
145-
for _, b := range sorted(named) {
147+
for _, b := range util.Sorted(named) {
146148
fmt.Fprintln(buf)
147149
fmt.Fprint(buf, b.String())
148150
}
@@ -242,8 +244,8 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma
242244
// unmarshal them into a map[string]any, or delay unmarshalling with
243245
// json.RawMessage. For now, use json.RawMessage as it defers the choice.
244246
if def.Type == "object" && canHaveAdditionalProperties(def) {
245-
w.Write([]byte("json.RawMessage"))
246-
return nil
247+
w.Write([]byte("map[string]"))
248+
return writeType(w, nil, def.AdditionalProperties, named)
247249
}
248250

249251
if def.Type == "" {
@@ -269,7 +271,7 @@ func writeType(w io.Writer, config *typeConfig, def *jsonschema.Schema, named ma
269271

270272
case "object":
271273
fmt.Fprintf(w, "struct {\n")
272-
for name, fieldDef := range sorted(def.Properties) {
274+
for name, fieldDef := range util.Sorted(def.Properties) {
273275
if fieldDef.Description != "" {
274276
fmt.Fprintf(w, "%s\n", toComment(fieldDef.Description))
275277
}
@@ -385,28 +387,3 @@ func assert(cond bool, msg string) {
385387
panic(msg)
386388
}
387389
}
388-
389-
// Helpers below are copied from gopls' moremaps package.
390-
391-
// sorted returns an iterator over the entries of m in key order.
392-
func sorted[M ~map[K]V, K cmp.Ordered, V any](m M) iter.Seq2[K, V] {
393-
// TODO(adonovan): use maps.Sorted if proposal #68598 is accepted.
394-
return func(yield func(K, V) bool) {
395-
keys := keySlice(m)
396-
slices.Sort(keys)
397-
for _, k := range keys {
398-
if !yield(k, m[k]) {
399-
break
400-
}
401-
}
402-
}
403-
}
404-
405-
// keySlice returns the keys of the map M, like slices.Collect(maps.Keys(m)).
406-
func keySlice[M ~map[K]V, K comparable, V any](m M) []K {
407-
r := make([]K, 0, len(m))
408-
for k := range m {
409-
r = append(r, k)
410-
}
411-
return r
412-
}

0 commit comments

Comments
 (0)