diff --git a/command.go b/command.go index 115c16a..9b32bff 100644 --- a/command.go +++ b/command.go @@ -59,6 +59,19 @@ type Command struct { Options OptionSet Annotations Annotations + // Tool is the name of the MCP tool this command provides. + // If set, the command can be invoked via MCP as a tool. + // Tool and Resource are mutually exclusive. + Tool string + + // ToolFlags is a set of flags to automatically set for a given MCP command. + ToolFlags []string + + // Resource is the URI of the MCP resource this command provides. + // If set, the command can be accessed via MCP as a resource. + // Tool and Resource are mutually exclusive. + Resource string + // Middleware is called before the Handler. // Use Chain() to combine multiple middlewares. Middleware MiddlewareFunc @@ -106,6 +119,11 @@ func (c *Command) init() error { } var merr error + // Validate that Tool and Resource are mutually exclusive + if c.Tool != "" && c.Resource != "" { + merr = errors.Join(merr, xerrors.Errorf("command %q cannot have both Tool and Resource set", c.Name())) + } + for i := range c.Options { opt := &c.Options[i] if opt.Name == "" { @@ -558,6 +576,12 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) { return -1, xerrors.Errorf("arg %s not found", want) } +// IsMCPEnabled returns true if the command is accessible via MCP +// (has either Tool or Resource field set) +func (c *Command) IsMCPEnabled() bool { + return c.Tool != "" || c.Resource != "" +} + // Run executes the command. // If two command share a flag name, the first command wins. // diff --git a/example/mcp/README.md b/example/mcp/README.md new file mode 100644 index 0000000..22e32b0 --- /dev/null +++ b/example/mcp/README.md @@ -0,0 +1,167 @@ +# Serpent MCP Server Example + +This example demonstrates how to use the Model Context Protocol (MCP) functionality in Serpent to create a command-line tool that can also be used as an MCP server. + +## What is MCP? + +The Model Context Protocol (MCP) is a protocol for communication between AI models and external tools or resources. It allows AI models to invoke tools and access resources provided by MCP servers. + +## How to Use + +### Running as a CLI Tool + +You can run the example as a normal CLI tool: + +```bash +# Echo a message +go run main.go echo "Hello, World!" + +# Get version information +go run main.go version + +# Show help +go run main.go --help +``` + +### Running as an MCP Server + +You can run the example as an MCP server using the `mcp` subcommand: + +```bash +go run main.go mcp +``` + +This will start an MCP server that listens on stdin/stdout for JSON-RPC 2.0 requests. + +## MCP Protocol + +### Lifecycle + +The MCP server follows the standard MCP lifecycle: + +1. The client sends an `initialize` request to the server +2. The server responds with its capabilities +3. The client sends an `initialized` notification +4. After this, normal message exchange can begin + +All MCP methods will return an error if called before the initialization process is complete. + +### Methods + +The MCP server implements the following JSON-RPC 2.0 methods: + +- `initialize`: Initializes the MCP server and returns its capabilities +- `notifications/initialized`: Notifies the server that initialization is complete +- `ping`: Simple ping method to check server availability +- `tools/list`: Lists all available tools +- `tools/call`: Invokes a tool with the given arguments +- `resources/list`: Lists all available resources +- `resources/templates/list`: Lists all available resource templates +- `resources/read`: Accesses a resource with the given URI + +### Example Requests + +Here are some example JSON-RPC 2.0 requests you can send to the MCP server: + +#### Initialize + +```json +{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","clientInfo":{"name":"manual-test-client","version":"1.0.0"},"capabilities":{}}} +``` + +Response: +```json +{"jsonrpc":"2.0","id":1,"result":{"capabilities":{"tools":true,"resources":true}}} +``` + +#### Initialized + +```json +{"jsonrpc":"2.0","id":2,"method":"notifications/initialized"} +``` + +#### List Tools + +```json +{"jsonrpc":"2.0","id":3,"method":"tools/list","params":{}} +``` + +#### List Resources + +```json +{"jsonrpc":"2.0","id":4,"method":"resources/list","params":{}} +``` + +#### Invoke Tool + +```json +{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"echo","arguments":{"_":"Hello from MCP!"}}} +``` + +#### Access Resource + +```json +{"jsonrpc":"2.0","id":6,"method":"resources/read","params":{"uri":"version"}} +``` + +### Complete Initialization Example + +Here's a complete example of the initialization process: + +```json +// Client sends initialize request +{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","clientInfo":{"name":"manual-test-client","version":"1.0.0"},"capabilities":{}}} + +// Server responds with capabilities +{"jsonrpc":"2.0","id":1,"result":{"capabilities":{"tools":true,"resources":true}}} + +// Client sends initialized notification +{"jsonrpc":"2.0","id":2,"method":"notifications/initialized"} + +// Server acknowledges (optional, since initialized is technically a notification) +{"jsonrpc":"2.0","id":2,"result":{}} + +// Now client can use MCP methods +{"jsonrpc":"2.0","id":3,"method":"tools/list","params":{}} +``` + +## How to Implement MCP in Your Own Commands + +To implement MCP in your own Serpent commands: + +1. Add the `Tool` field to commands that should be invokable as MCP tools +2. Add the `Resource` field to commands that should be accessible as MCP resources +3. Add the MCP command to your root command using `root.AddMCPCommand()` + +Example: + +```go +// Create a command that will be exposed as an MCP tool +echoCmd := &serpent.Command{ + Use: "echo [message]", + Short: "Echo a message", + Tool: "echo", // This makes the command available as an MCP tool + Handler: func(inv *serpent.Invocation) error { + // Command implementation + }, +} + +// Create a command that will be exposed as an MCP resource +versionCmd := &serpent.Command{ + Use: "version", + Short: "Get version information", + Resource: "version", // This makes the command available as an MCP resource + Handler: func(inv *serpent.Invocation) error { + // Command implementation + }, +} + +// Add the MCP command to the root command +root.AddSubcommands(serpent.MCPCommand()) +``` + +## Notes + +- A command can have either a `Tool` field or a `Resource` field, but not both +- Commands with neither `Tool` nor `Resource` set will not be accessible via MCP +- The MCP server communicates using JSON-RPC 2.0 over stdin/stdout diff --git a/example/mcp/main.go b/example/mcp/main.go new file mode 100644 index 0000000..45ff557 --- /dev/null +++ b/example/mcp/main.go @@ -0,0 +1,87 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/coder/serpent" +) + +func main() { + // Create a root command + root := &serpent.Command{ + Use: "mcp-example", + Short: "Example MCP server", + Long: "An example of how to use the MCP functionality in serpent.", + } + + var repeats int64 = 2 + + // Add a command that will be exposed as an MCP tool + echoCmd := &serpent.Command{ + Use: "echo [message]", + Short: "Echo a message", + Tool: "echo", // This makes the command available as an MCP tool + Options: []serpent.Option{ + { + Name: "repeat", + Flag: "repeat", // Add the Flag field so it's exposed in JSON Schema + Description: "Number of times to repeat the message.", + Default: "2", + Value: serpent.Int64Of(&repeats), + }, + }, + Handler: func(inv *serpent.Invocation) error { + message := "Hello, World!" + if len(inv.Args) > 0 { + message = strings.Join(inv.Args, " ") + } + for i := int64(0); i < repeats; i++ { + if _, err := fmt.Fprintln(inv.Stdout, message); err != nil { + return err + } + } + return nil + }, + } + root.AddSubcommands(echoCmd) + + // Add a command that will be exposed as an MCP resource + versionCmd := &serpent.Command{ + Use: "version", + Short: "Get version information", + Resource: "version", // This makes the command available as an MCP resource + Handler: func(inv *serpent.Invocation) error { + version := map[string]string{ + "version": "1.0.0", + "name": "serpent-mcp-example", + "author": "Coder", + } + encoder := json.NewEncoder(inv.Stdout) + return encoder.Encode(version) + }, + } + root.AddSubcommands(versionCmd) + + // Add a command that will not be exposed via MCP + hiddenCmd := &serpent.Command{ + Use: "hidden", + Short: "This command is not exposed via MCP", + Handler: func(inv *serpent.Invocation) error { + _, err := fmt.Fprintln(inv.Stdout, "This command is not exposed via MCP") + return err + }, + } + root.AddSubcommands(hiddenCmd) + + // Add the MCP command to the root command + root.AddSubcommands(serpent.MCPCommand()) + + // Run the command + if err := root.Invoke(os.Args[1:]...).WithOS().Run(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} diff --git a/go.mod b/go.mod index 1c2880c..932e80f 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,8 @@ require ( github.com/pion/udp v0.1.4 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 - golang.org/x/crypto v0.19.0 golang.org/x/exp v0.0.0-20240213143201-ec583247a57a + golang.org/x/term v0.17.0 golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 gopkg.in/yaml.v3 v3.0.1 ) @@ -24,7 +24,7 @@ require ( github.com/charmbracelet/lipgloss v0.8.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/go-logr/logr v1.4.1 // indirect - github.com/google/go-cmp v0.6.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect @@ -39,7 +39,6 @@ require ( go.opentelemetry.io/otel/trace v1.19.0 // indirect golang.org/x/net v0.21.0 // indirect golang.org/x/sys v0.17.0 // indirect - golang.org/x/term v0.17.0 // indirect google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f // indirect diff --git a/go.sum b/go.sum index a1106fc..b2a3b45 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -92,8 +94,6 @@ go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1 go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/exp v0.0.0-20240213143201-ec583247a57a h1:HinSgX1tJRX3KsL//Gxynpw5CTOAIPhgL4W8PNiIpVE= golang.org/x/exp v0.0.0-20240213143201-ec583247a57a/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= diff --git a/help.go b/help.go index 1753945..6e96194 100644 --- a/help.go +++ b/help.go @@ -15,7 +15,7 @@ import ( "github.com/mitchellh/go-wordwrap" "github.com/muesli/termenv" - "golang.org/x/crypto/ssh/terminal" + "golang.org/x/term" "golang.org/x/xerrors" "github.com/coder/pretty" @@ -31,7 +31,7 @@ type optionGroup struct { } func ttyWidth() int { - width, _, err := terminal.GetSize(0) + width, _, err := term.GetSize(0) if err != nil { return 80 } @@ -73,10 +73,8 @@ func prettyHeader(s string) string { } var defaultHelpTemplate = func() *template.Template { - var ( - optionFg = pretty.FgColor( - helpColor("#04A777"), - ) + optionFg := pretty.FgColor( + helpColor("#04A777"), ) return template.Must( template.New("usage").Funcs( diff --git a/help.tpl b/help.tpl index 3d4f1c9..c9b32fb 100644 --- a/help.tpl +++ b/help.tpl @@ -17,6 +17,14 @@ {{" Aliases: "}} {{- joinStrings .}} {{- end }} +{{ with .Tool }} +{{" MCP Tool: "}} {{- . }} +{{- end }} + +{{ with .Resource }} +{{" MCP Resource: "}} {{- . }} +{{- end }} + {{- with .Long}} {{"\n"}} {{- indent . 2}} diff --git a/mcp.go b/mcp.go new file mode 100644 index 0000000..d91eeba --- /dev/null +++ b/mcp.go @@ -0,0 +1,1048 @@ +package serpent + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "path" + "slices" + "strings" + + "golang.org/x/xerrors" +) + +// JSONRPC2 message types +type JSONRPC2Request struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type JSONRPC2Response struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *JSONRPC2Error `json:"error,omitempty"` +} + +type JSONRPC2Notification struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type JSONRPC2Error struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data,omitempty"` +} + +// MCP protocol message types +type InitializeParams struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` + } `json:"clientInfo"` + Capabilities map[string]any `json:"capabilities"` +} + +type InitializeResult struct { + ProtocolVersion string `json:"protocolVersion"` + ServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` + } `json:"serverInfo"` + Capabilities map[string]any `json:"capabilities"` +} + +type ListToolsParams struct { + Cursor string `json:"cursor,omitempty"` +} + +type ListToolsResult struct { + Tools []Tool `json:"tools"` + NextCursor string `json:"nextCursor,omitempty"` +} + +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"inputSchema,omitempty"` +} + +type CallToolParams struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +type CallToolResult struct { + Content []ToolContent `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +type ToolContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + MimeType string `json:"mimeType,omitempty"` + Data string `json:"data,omitempty"` +} + +type ListResourcesParams struct { + Cursor string `json:"cursor,omitempty"` +} + +type ListResourcesResult struct { + Resources []Resource `json:"resources"` + NextCursor string `json:"nextCursor,omitempty"` +} + +type Resource struct { + URI string `json:"uri"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` + Size int `json:"size,omitempty"` +} + +type ListResourceTemplatesParams struct { + Cursor string `json:"cursor,omitempty"` +} + +type ListResourceTemplatesResult struct { + ResourceTemplates []ResourceTemplate `json:"resourceTemplates"` + NextCursor string `json:"nextCursor,omitempty"` +} + +type ResourceTemplate struct { + URITemplate string `json:"uriTemplate"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` +} + +type ReadResourceParams struct { + URI string `json:"uri"` +} + +type ReadResourceResult struct { + Contents []ResourceContent `json:"contents"` +} + +type ResourceContent struct { + URI string `json:"uri"` + MimeType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` + Blob string `json:"blob,omitempty"` +} + +// JSON-RPC 2.0 error codes +const ( + // Standard JSON-RPC error codes + ErrorCodeParseError = -32700 + ErrorCodeInvalidRequest = -32600 + ErrorCodeMethodNotFound = -32601 + ErrorCodeInvalidParams = -32602 + ErrorCodeInternalError = -32603 + + // MCP specific error codes + ErrorCodeResourceNotFound = -32002 + ErrorCodeResourceUnavailable = -32001 + ErrorCodeToolNotFound = -32100 + ErrorCodeToolUnavailable = -32101 +) + +// MCPServer represents an MCP server that can handle tool invocations and resource access +type MCPServer struct { + rootCmd *Command + stdin io.Reader + stdout io.Writer + stderr io.Writer + cmdFinder CommandFinder + toolCmds map[string]*Command + toolFlags map[string][]string // overrides flags for given commands + resourceCmds map[string]*Command + resourceTemplates map[string]*Command // Maps URI templates to commands + initialized bool // Track if the server has been initialized + protocolVersion string // Protocol version negotiated during initialization +} + +// CommandFinder is a function that finds a command by name +type CommandFinder func(rootCmd *Command, name string) *Command + +// DefaultCommandFinder is the default implementation of CommandFinder +func DefaultCommandFinder(rootCmd *Command, name string) *Command { + parts := strings.Split(name, " ") + cmd := rootCmd + + for _, part := range parts { + found := false + for _, child := range cmd.Children { + if child.Name() == part { + cmd = child + found = true + break + } + } + if !found { + return nil + } + } + + return cmd +} + +// NewMCPServer creates a new MCP server +func NewMCPServer(rootCmd *Command, stdin io.Reader, stdout, stderr io.Writer) *MCPServer { + server := &MCPServer{ + rootCmd: rootCmd, + stdin: stdin, + stdout: stdout, + stderr: stderr, + cmdFinder: DefaultCommandFinder, + toolCmds: make(map[string]*Command), + toolFlags: make(map[string][]string), + resourceCmds: make(map[string]*Command), + resourceTemplates: make(map[string]*Command), + protocolVersion: "2025-03-26", // Default to latest version + } + + // Index all commands with Tool or Resource fields + rootCmd.Walk(func(cmd *Command) { + if cmd.Tool != "" { + server.toolCmds[cmd.Tool] = cmd + if len(cmd.ToolFlags) > 0 { + server.toolFlags[cmd.Tool] = cmd.ToolFlags[:] + } + } + if cmd.Resource != "" { + if strings.Contains(cmd.Resource, "{") && strings.Contains(cmd.Resource, "}") { + // This is a URI template + server.resourceTemplates[cmd.Resource] = cmd + } else { + // This is a static resource URI + server.resourceCmds[cmd.Resource] = cmd + } + } + }) + + return server +} + +// Run starts the MCP server +func (s *MCPServer) Run(ctx context.Context) error { + // Check if context is already done + select { + case <-ctx.Done(): + return nil + default: + // Continue with normal operation + } + + // Create a buffered reader for stdin + reader := bufio.NewReader(s.stdin) + + // Process requests until context is done or EOF + for { + // Check if context is done + select { + case <-ctx.Done(): + return nil + default: + // Continue processing + } + + // Try to read a line with a non-blocking approach + var line string + + // Use a channel to communicate when a line is read + lineCh := make(chan string, 1) + errCh := make(chan error, 1) + + // Start a goroutine to read a line + go func() { + text, err := reader.ReadString('\n') + if err != nil { + errCh <- err + return + } + lineCh <- strings.TrimSpace(text) + }() + + // Wait for either a line to be read, an error, or context cancellation + select { + case <-ctx.Done(): + // Context was canceled, exit gracefully + return nil + case err := <-errCh: + if err == io.EOF { + // End of input, exit normally + return nil + } + // Other error + return xerrors.Errorf("reading stdin: %w", err) + case line = <-lineCh: + // Line was read successfully, process it + } + + // Parse the JSON-RPC request + var req JSONRPC2Request + if err := json.Unmarshal([]byte(line), &req); err != nil { + s.sendErrorResponse(nil, ErrorCodeParseError, "Failed to parse JSON-RPC request", nil) + continue + } + + // Ensure this is a JSON-RPC 2.0 request + if req.JSONRPC != "2.0" { + s.sendErrorResponse(req.ID, ErrorCodeInvalidRequest, "Invalid JSON-RPC version, expected 2.0", nil) + continue + } + + // Handle the request based on the method + switch req.Method { + case "initialize": + s.handleInitialize(req) + case "notifications/initialized": + s.handleInitialized(req) + case "ping": + s.handlePing(req) + case "tools/list": + if !s.initialized { + s.sendErrorResponse(req.ID, ErrorCodeInvalidRequest, "Server not initialized", nil) + continue + } + s.handleListTools(req) + case "tools/call": + if !s.initialized { + s.sendErrorResponse(req.ID, ErrorCodeInvalidRequest, "Server not initialized", nil) + continue + } + s.handleCallTool(req) + case "resources/list": + if !s.initialized { + s.sendErrorResponse(req.ID, ErrorCodeInvalidRequest, "Server not initialized", nil) + continue + } + s.handleListResources(req) + case "resources/templates/list": + if !s.initialized { + s.sendErrorResponse(req.ID, ErrorCodeInvalidRequest, "Server not initialized", nil) + continue + } + s.handleListResourceTemplates(req) + case "resources/read": + if !s.initialized { + s.sendErrorResponse(req.ID, ErrorCodeInvalidRequest, "Server not initialized", nil) + continue + } + s.handleReadResource(req) + default: + s.sendErrorResponse(req.ID, ErrorCodeMethodNotFound, fmt.Sprintf("Method not found: %s", req.Method), nil) + } + } +} + +// handlePing handles the ping method, responding with an empty result +func (s *MCPServer) handlePing(req JSONRPC2Request) { + s.sendSuccessResponse(req.ID, struct{}{}) +} + +// handleListTools handles the tools/list method +func (s *MCPServer) handleListTools(req JSONRPC2Request) { + if _, err := UnmarshalParamsLenient[ListToolsParams](req.Params); err != nil { + s.sendErrorResponse(req.ID, ErrorCodeInvalidParams, "Invalid parameters", nil) + return + } + + tools := make([]Tool, 0, len(s.toolCmds)) + for name, cmd := range s.toolCmds { + // Generate a proper JSON Schema from the command's options + schema, err := s.generateJSONSchema(cmd) + if err != nil { + fmt.Fprintf(s.stderr, "Failed to generate schema for tool %s: %v\n", name, err) + schema = json.RawMessage(`{}`) + } + + tools = append(tools, Tool{ + Name: name, + Description: cmd.Use + " -- " + cmd.Short, + InputSchema: schema, + }) + } + + response := ListToolsResult{ + Tools: tools, + // We're not implementing pagination for now + NextCursor: "", + } + s.sendSuccessResponse(req.ID, response) +} + +// handleListResources handles the resources/list method +func (s *MCPServer) handleListResources(req JSONRPC2Request) { + if _, err := UnmarshalParamsLenient[ListResourcesParams](req.Params); err != nil { + s.sendErrorResponse(req.ID, ErrorCodeInvalidParams, "Invalid parameters", nil) + return + } + + resources := make([]Resource, 0, len(s.resourceCmds)) + for uri, cmd := range s.resourceCmds { + resources = append(resources, Resource{ + URI: uri, + Name: cmd.Name(), + Description: cmd.Short, + MimeType: "application/json", // Default MIME type + }) + } + + response := ListResourcesResult{ + Resources: resources, + // We're not implementing pagination for now + NextCursor: "", + } + s.sendSuccessResponse(req.ID, response) +} + +// handleListResourceTemplates handles the resources/templates/list method +func (s *MCPServer) handleListResourceTemplates(req JSONRPC2Request) { + _, err := UnmarshalParamsLenient[ListResourceTemplatesParams](req.Params) + if err != nil { + s.sendErrorResponse(req.ID, ErrorCodeInvalidParams, "Invalid parameters", nil) + return + } + + templates := make([]ResourceTemplate, 0, len(s.resourceTemplates)) + for uriTemplate, cmd := range s.resourceTemplates { + templates = append(templates, ResourceTemplate{ + URITemplate: uriTemplate, + Name: cmd.Name(), + Description: cmd.Short, + MimeType: "application/json", // Default MIME type + }) + } + + response := ListResourceTemplatesResult{ + ResourceTemplates: templates, + // We're not implementing pagination for now + NextCursor: "", + } + s.sendSuccessResponse(req.ID, response) +} + +// errReader is an io.Reader that never reads successfully, returning a predefined error. +type errReader string + +// Read implements io.Reader +func (r errReader) Read([]byte) (int, error) { + return 0, errors.New(string(r)) +} + +// Commands may attempt to read stdin. This error is returned on any attempted read from stdin from a command invocation. +var dontReadStdin = errReader("This command is attempting to read from stdin, which indicates that it is missing one or more required arguments.") + +// handleCallTool handles the tools/call method +func (s *MCPServer) handleCallTool(req JSONRPC2Request) { + params, err := UnmarshalParamsLenient[CallToolParams](req.Params) + if err != nil { + s.sendErrorResponse(req.ID, ErrorCodeInvalidParams, "Invalid parameters", nil) + return + } + + cmd, ok := s.toolCmds[params.Name] + if !ok { + s.sendErrorResponse(req.ID, ErrorCodeToolNotFound, fmt.Sprintf("Tool not found: %s", params.Name), nil) + return + } + + // Create a new invocation with captured stdout/stderr + var stdout, stderr strings.Builder + inv := cmd.Invoke() + inv.Stdin = dontReadStdin // MCP tools have no stdin. + inv.Stdout = &stdout + inv.Stderr = &stderr + + // Parse the arguments as a map and convert to command-line args + var args map[string]any + if err := json.Unmarshal(params.Arguments, &args); err != nil { + s.sendErrorResponse(req.ID, ErrorCodeInvalidParams, "Invalid arguments format", nil) + return + } + // Convert the arguments map to command-line args + var cmdArgs []string + + // Check for positional arguments prefix with `argN__` + deleteKeys := make([]string, 0) + for k, v := range args { + if strings.HasPrefix(k, "arg") && len(k) > 4 && k[3] >= '0' && k[3] <= '9' { + deleteKeys = append(deleteKeys, k) + switch val := v.(type) { + case string: + cmdArgs = append(cmdArgs, val) + case []any: + for _, item := range val { + cmdArgs = append(cmdArgs, fmt.Sprintf("%v", item)) + } + default: + cmdArgs = append(cmdArgs, fmt.Sprintf("%v", val)) + } + } + } + // Delete any of the positional argument keys so they don't get processed below. + for _, dk := range deleteKeys { + delete(args, dk) + } + + // Process remaining arguments as flags + for k, v := range args { + switch val := v.(type) { + case bool: + if val { + cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", k)) + } else { + cmdArgs = append(cmdArgs, fmt.Sprintf("--%s=false", k)) + } + case []any: + for _, item := range val { + cmdArgs = append(cmdArgs, fmt.Sprintf("--%s=%v", k, item)) + } + default: + cmdArgs = append(cmdArgs, fmt.Sprintf("--%s=%v", k, v)) + } + } + + // Finally, add any overridden flags at the tool level. + if toolFlags, ok := s.toolFlags[params.Name]; ok { + for _, flag := range toolFlags { + cmdArgs = append(cmdArgs, flag) + } + } + + inv.Args = cmdArgs + + // Run the command + err = inv.Run() + + // Prepare the response following MCP specification + var content []ToolContent + if stdout.Len() > 0 { + content = append(content, ToolContent{ + Type: "text", + Text: stdout.String(), + }) + } + + if stderr.Len() > 0 { + content = append(content, ToolContent{ + Type: "text", + Text: stderr.String(), + }) + } + + // Add error, if present. + if err != nil { + content = append(content, ToolContent{ + Type: "text", + Text: err.Error(), + }) + } + + // If still no content, add empty result + if len(content) == 0 { + content = append(content, ToolContent{ + Type: "text", + Text: "", + }) + } + + response := CallToolResult{ + Content: content, + IsError: err != nil, + } + + s.sendSuccessResponse(req.ID, response) +} + +// handleReadResource handles the resources/read method +func (s *MCPServer) handleReadResource(req JSONRPC2Request) { + params, err := UnmarshalParamsLenient[ReadResourceParams](req.Params) + if err != nil { + s.sendErrorResponse(req.ID, ErrorCodeInvalidParams, "Invalid parameters", nil) + return + } + + // First check if this is a direct resource URI match + cmd, ok := s.resourceCmds[params.URI] + if !ok { + // If not a direct match, check if it matches any URI template + for template, templateCmd := range s.resourceTemplates { + // Very basic template matching - would need more complex handling for real URI templates + if matched, _ := path.Match(template, params.URI); matched { + cmd = templateCmd + ok = true + break + } + } + + if !ok { + s.sendErrorResponse(req.ID, ErrorCodeResourceNotFound, fmt.Sprintf("Resource not found: %s", params.URI), nil) + return + } + } + + // Create a new invocation with captured stdout + var stdout strings.Builder + inv := cmd.Invoke() + inv.Stdout = &stdout + + // Run the command + if err := inv.Run(); err != nil { + s.sendErrorResponse(req.ID, ErrorCodeResourceUnavailable, fmt.Sprintf("Resource unavailable: %s", err.Error()), nil) + return + } + + // Create the response with the text content + response := ReadResourceResult{ + Contents: []ResourceContent{ + { + URI: params.URI, + MimeType: "application/json", // Assuming JSON by default + Text: stdout.String(), + }, + }, + } + + s.sendSuccessResponse(req.ID, response) +} + +// sendSuccessResponse sends a successful JSON-RPC response +func (s *MCPServer) sendSuccessResponse(id any, result any) { + resultBytes, err := json.Marshal(result) + if err != nil { + s.sendErrorResponse(id, ErrorCodeInternalError, "Failed to marshal result", nil) + return + } + + response := JSONRPC2Response{ + JSONRPC: "2.0", + ID: id, + Result: resultBytes, + } + + s.sendResponse(response) +} + +// generateJSONSchema generates a JSON Schema for a command's options +func (s *MCPServer) generateJSONSchema(cmd *Command) (json.RawMessage, error) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{}, + "required": []string{}, + } + + properties := schema["properties"].(map[string]any) + requiredList := schema["required"].([]string) + + // Add positional arguments based on the cmd usage. + if posArgs, err := PosArgsFromCmdUsage(cmd.Use); err != nil { + return nil, xerrors.Errorf("unable to process positional argument for command %q: %w", cmd.Name(), err) + } else { + for k, v := range posArgs { + properties[k] = v + } + } + + // Process each option in the command + for _, opt := range cmd.Options { + // Skip options that aren't exposed as flags + if opt.Flag == "" { + continue + } + // Skip hidden options + if opt.Hidden { + continue + } + + property := map[string]any{ + "description": opt.Description, + } + + // Determine JSON Schema type using pflag.Value.Type() + valueType := opt.Value.Type() + + switch valueType { + case "string": + property["type"] = "string" + // Special handling for file paths + if opt.Flag == "file-path" { + property["format"] = "path" + } + case "bool": + property["type"] = "boolean" + case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64": + property["type"] = "integer" // Use integer for whole numbers + case "float32", "float64": + property["type"] = "number" + case "ip", "ipMask", "ipNet", "count": // Specific pflag types + // Count is integer, others are strings + if valueType == "count" { + property["type"] = "integer" + } else { + property["type"] = "string" + } + case "duration": + property["type"] = "string" // Represent duration as string (e.g., "1h", "30m") + property["format"] = "duration" + // Handle slice types + case "stringSlice": + property["type"] = "array" + property["items"] = map[string]any{"type": "string"} + case "boolSlice": + property["type"] = "array" + property["items"] = map[string]any{"type": "boolean"} + case "intSlice", "int32Slice", "int64Slice", "uintSlice": + property["type"] = "array" + property["items"] = map[string]any{"type": "integer"} + case "float32Slice", "float64Slice": + property["type"] = "array" + property["items"] = map[string]any{"type": "number"} + case "ipSlice": + property["type"] = "array" + property["items"] = map[string]any{"type": "string"} + case "durationSlice": + property["type"] = "array" + property["items"] = map[string]any{"type": "string", "format": "duration"} + case "string-array", "stringArray", "stringToString", "stringToInt", "stringToInt64": // More pflag types + // stringArray is like stringSlice + // Map types are complex, represent as object for now + if valueType == "stringArray" { + property["type"] = "array" + property["items"] = map[string]any{"type": "string"} + } else { + property["type"] = "object" + property["additionalProperties"] = map[string]any{ + "type": "string", // Default to string value type for maps + } + if valueType == "stringToInt" || valueType == "stringToInt64" { + property["additionalProperties"] = map[string]any{ + "type": "integer", + } + } + } + // Handle custom serpent types + default: + // Check for known serpent custom types (Enum, EnumArray) + if enum, ok := opt.Value.(*Enum); ok { + property["type"] = "string" + property["enum"] = enum.Choices + } else if enumArray, ok := opt.Value.(*EnumArray); ok { + property["type"] = "array" + property["items"] = map[string]any{ + "type": "string", + "enum": enumArray.Choices, + } + } else { + // Fallback for unknown types + property["type"] = "string" + fmt.Fprintf(s.stderr, "Warning: Unknown pflag type '%s' for option '%s', defaulting to string\n", valueType, opt.Flag) + } + } + + // Add the property definition + properties[opt.Flag] = property + + // Add to required list if Required is true AND Default is not set + // (as per comment in option.go) + if opt.Required && opt.Default == "" { + requiredList = append(requiredList, opt.Flag) + } + } + + // Update required field only if it's not empty + if len(requiredList) > 0 { + schema["required"] = requiredList + } else { + // Remove the empty required array if no options are required + delete(schema, "required") + } + + return json.MarshalIndent(schema, "", " ") // Use MarshalIndent for readability +} + +// sendErrorResponse sends an error JSON-RPC response +func (s *MCPServer) sendErrorResponse(id any, code int, message string, data any) { + var dataBytes json.RawMessage + if data != nil { + var err error + dataBytes, err = json.Marshal(data) + if err != nil { + // If we can't marshal the data, just ignore it + dataBytes = nil + } + } + + response := JSONRPC2Response{ + JSONRPC: "2.0", + ID: id, + Error: &JSONRPC2Error{ + Code: code, + Message: message, + Data: dataBytes, + }, + } + + s.sendResponse(response) +} + +// sendResponse sends a JSON-RPC response to stdout +func (s *MCPServer) sendResponse(response JSONRPC2Response) { + responseBytes, err := json.Marshal(response) + if err != nil { + fmt.Fprintf(s.stderr, "Failed to marshal response: %v\n", err) + return + } + + fmt.Fprintln(s.stdout, string(responseBytes)) +} + +// unmarshalParamsLenient unmarshals JSON parameters in a lenient way, handling omitted optional parameters. +// If the raw JSON is empty, null, or "[]", it will use the default zero value for the target struct. +// If the JSON contains invalid data (not empty), the original error is returned. +// UnmarshalParamsLenient unmarshals JSON parameters in a lenient way, handling omitted optional parameters. +// If the raw JSON is empty, null, or "[]", it will return the zero value for type T. +// If the JSON contains invalid data (not empty), the original error is returned. +func UnmarshalParamsLenient[T any](data json.RawMessage) (T, error) { + var result T + + // If params is nil, empty, or just whitespace/empty array, use zero value + if len(data) == 0 || string(data) == "null" || string(data) == "{}" || string(data) == "[]" { + // Return the zero value of T + return result, nil + } + + // Otherwise, try to unmarshal normally + err := json.Unmarshal(data, &result) + if err != nil { + // Only return error if the JSON is not empty and contains invalid data + return result, err + } + + return result, nil +} + +// handleInitialize handles the initialize request +func (s *MCPServer) handleInitialize(req JSONRPC2Request) { + params, err := UnmarshalParamsLenient[InitializeParams](req.Params) + if err != nil { + s.sendErrorResponse(req.ID, ErrorCodeInvalidParams, "Invalid parameters", nil) + return + } + + // Negotiate protocol version + if params.ProtocolVersion != "" { + // For now, we just accept the client's protocol version + // In a real implementation, you would compare with supported versions + s.protocolVersion = params.ProtocolVersion + } + + // Determine if we have tools and resources + hasTools := len(s.toolCmds) > 0 + hasResources := len(s.resourceCmds) > 0 || len(s.resourceTemplates) > 0 + + // Create capabilities object following MCP 2025-03-26 spec + capabilities := map[string]any{} + + if hasTools { + capabilities["tools"] = map[string]any{ + "listChanged": false, // We don't support list change notifications yet + } + } + + if hasResources { + capabilities["resources"] = map[string]any{ + "listChanged": false, // We don't support list change notifications yet + "subscribe": false, // We don't support subscriptions yet + } + } + + // Create the response + result := InitializeResult{ + ProtocolVersion: s.protocolVersion, + ServerInfo: struct { + Name string `json:"name"` + Version string `json:"version"` + }{ + Name: "serpent-mcp", + Version: "1.0.0", + }, + Capabilities: capabilities, + } + + s.initialized = true + + // Send the response + s.sendSuccessResponse(req.ID, result) +} + +// handleInitialized handles the initialized notification +func (s *MCPServer) handleInitialized(req JSONRPC2Request) { + // Mark the server as initialized + s.initialized = true + + // No response needed for a notification + // But we'll send a success response anyway since our request has an ID + if req.ID != nil { + s.sendSuccessResponse(req.ID, struct{}{}) + } +} + +// MCPCommand creates a generic command that can run any serpent command as an MCP server +func MCPCommand() *Command { + return &Command{ + Use: "mcp [command]", + Short: "Run a command as an MCP server", + Long: `Run a command as a Model Context Protocol (MCP) server over stdio. + +This command allows any serpent command to be exposed as an MCP server, which can +provide tools and resources to MCP clients. The server communicates using JSON-RPC 2.0 +over stdin/stdout. + +If a command name is provided, that specific command will be run as an MCP server. +Otherwise, the root command will be used. + +Commands with a Tool field set can be invoked as MCP tools. +Commands with a Resource field set can be accessed as MCP resources. +Commands with neither Tool nor Resource set will not be accessible via MCP.`, + Handler: func(inv *Invocation) error { + rootCmd := inv.Command + if rootCmd.Parent != nil { + // Find the root command + for rootCmd.Parent != nil { + rootCmd = rootCmd.Parent + } + } + + // If a command name is provided, use that as the root + if len(inv.Args) > 0 { + cmdName := strings.Join(inv.Args, " ") + cmd := DefaultCommandFinder(rootCmd, cmdName) + if cmd == nil { + return xerrors.Errorf("command not found: %s", cmdName) + } + rootCmd = cmd + } + + // Create and run the MCP server + server := NewMCPServer(rootCmd, inv.Stdin, inv.Stdout, inv.Stderr) + return server.Run(inv.Context()) + }, + } +} + +// PosArgsFromCmdUsage attempts to process a 'usage' string into a set of +// arguments for display as tool parameters. +// Example: the usage string `foo [flags] [baz] [razzle|dazzle]` +// defines three arguments for the `foo` command: +// - bar (required) +// - baz (optional) +// - the string `razzle` XOR `dazzle` (optional) +// +// The expected output of the above is as follows: +// +// { +// "arg1:bar": { +// "type": "string", +// "description": "required argument", +// }, +// "arg2:baz": { +// "type": "string", +// "description": "optional argument", +// }, +// "arg3:razzle_dazzle": { +// "type": "string", +// "enum": ["razzle", "dazzle"] +// }, +// } +// +// The usage string is processed given the following assumptions: +// 1. The first non-whitespace string of usage is the name of the command +// and will be skipped. +// 2. The pseudo-argument specifier [flags] will also be skipped, if present. +// 3. Argument specifiers enclosed by [square brackets] are considered optional. +// 4. All other argument specifiers are considered required. +// 5. Invidiual argument specifiers are separated by a single whitespace character. +// Argument specifiers that contain a space are considered invalid (e.g. `[foo bar]`) +// +// Variadic arguments [arg...] are treated as a single argument. +func PosArgsFromCmdUsage(usage string) (map[string]any, error) { + if len(usage) == 0 { + return nil, xerrors.Errorf("usage may not be empty") + } + + // Step 1: preprocessing. Skip the first token. + parts := strings.Fields(usage) + if len(parts) < 2 { + return map[string]any{}, nil + } + parts = parts[1:] + // Skip [flags], if present. + parts = slices.DeleteFunc(parts, func(s string) bool { + return s == "[flags]" + }) + + result := make(map[string]any, len(parts)) + + // Process each argument token + for i, part := range parts { + argIndex := i + 1 + argKey := fmt.Sprintf("arg%d__", argIndex) + + // Check for unbalanced brackets in the part. + // This catches cases like "command [flags] [a" or "command [flags] a b [c | d]" + // which would be split into multiple tokens by strings.Fields() + openSquare := strings.Count(part, "[") + closeSquare := strings.Count(part, "]") + openAngle := strings.Count(part, "<") + closeAngle := strings.Count(part, ">") + openBrace := strings.Count(part, "{") + closeBrace := strings.Count(part, "}") + + if openSquare != closeSquare { + return nil, xerrors.Errorf("malformed usage: unbalanced square bracket at %q", part) + } else if openAngle != closeAngle { + return nil, xerrors.Errorf("malformed usage: unbalanced angle bracket at %q", part) + } else if openBrace != closeBrace { + return nil, xerrors.Errorf("malformed usage: unbalanced brace at %q", part) + } + + // Determine if the argument is optional (enclosed in square brackets) + isOptional := openSquare > 0 + cleanName := strings.Trim(part, "[]{}<>.") + description := "required argument" + if isOptional { + description = "optional argument" + } + + argVal := map[string]any{ + "type": "string", + "description": description, + // "required": !isOptional, + } + + keyName := cleanName + // If an argument specifier contains a pipe, treat it as an enum. + if strings.Contains(cleanName, "|") { + choices := strings.Split(cleanName, "|") + // Create a name by joining alternatives with underscores + keyName = strings.Join(choices, "_") + argVal["enum"] = choices + } + argKey += keyName + result[argKey] = argVal + } + + return result, nil +} diff --git a/mcp_test.go b/mcp_test.go new file mode 100644 index 0000000..b8c4b4d --- /dev/null +++ b/mcp_test.go @@ -0,0 +1,451 @@ +package serpent + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestToolAndResourceFields(t *testing.T) { + // Test that a command with neither Tool nor Resource is not MCP-enabled + cmd := &Command{ + Use: "regular", + Short: "Regular command with no MCP fields", + } + if cmd.IsMCPEnabled() { + t.Error("Command without Tool or Resource should not be MCP-enabled") + } + + // Test that a command with Tool is MCP-enabled + toolCmd := &Command{ + Use: "tool-cmd", + Short: "Command with Tool field", + Tool: "example-tool", + } + if !toolCmd.IsMCPEnabled() { + t.Error("Command with Tool should be MCP-enabled") + } + + // Test that a command with Resource is MCP-enabled + resourceCmd := &Command{ + Use: "resource-cmd", + Short: "Command with Resource field", + Resource: "example-resource", + } + if !resourceCmd.IsMCPEnabled() { + t.Error("Command with Resource should be MCP-enabled") + } + + // Test that a command cannot have both Tool and Resource + invalidCmd := &Command{ + Use: "invalid-cmd", + Short: "Command with both Tool and Resource", + Tool: "example-tool", + Resource: "example-resource", + } + + if err := invalidCmd.init(); err == nil { + t.Error("Command with both Tool and Resource should fail initialization") + } +} + +func TestMCPServerSetup(t *testing.T) { + // Create a root command with subcommands having Tool and Resource + root := &Command{ + Use: "root", + Short: "Root command", + } + + toolCmd := &Command{ + Use: "tool-cmd", + Short: "Tool command", + Tool: "test-tool", + Handler: func(inv *Invocation) error { + fmt.Fprintln(inv.Stdout, "Tool executed!") + return nil + }, + } + + resourceCmd := &Command{ + Use: "resource-cmd", + Short: "Resource command", + Resource: "test-resource", + Handler: func(inv *Invocation) error { + fmt.Fprintln(inv.Stdout, `{"result": "Resource data"`) + return nil + }, + } + + templatedResourceCmd := &Command{ + Use: "templated-cmd", + Short: "Templated resource command", + Resource: "test/{param}", + Handler: func(inv *Invocation) error { + fmt.Fprintln(inv.Stdout, `{"template": "Resource template"}`) + return nil + }, + } + + root.AddSubcommands(toolCmd, resourceCmd, templatedResourceCmd) + + // Create a server with the root command + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + server := NewMCPServer(root, stdin, stdout, stderr) + + // Check if tools were indexed correctly + if len(server.toolCmds) != 1 { + t.Errorf("Expected 1 tool command, got %d", len(server.toolCmds)) + } + + if len(server.resourceCmds) != 1 { + t.Errorf("Expected 1 resource command, got %d", len(server.resourceCmds)) + } + + if len(server.resourceTemplates) != 1 { + t.Errorf("Expected 1 resource template, got %d", len(server.resourceTemplates)) + } + + if cmd, ok := server.toolCmds["test-tool"]; !ok || cmd != toolCmd { + t.Error("Tool command not properly indexed") + } + + if cmd, ok := server.resourceCmds["test-resource"]; !ok || cmd != resourceCmd { + t.Error("Resource command not properly indexed") + } + + if cmd, ok := server.resourceTemplates["test/{param}"]; !ok || cmd != templatedResourceCmd { + t.Error("Resource template command not properly indexed") + } +} + +func TestJSONSchemaGeneration(t *testing.T) { + // Create a command with various option types + cmd := &Command{ + Use: "test-schema", + Short: "Command for testing schema generation", + Options: OptionSet{ + { + Flag: "string-flag", + Description: "A string flag", + Value: StringOf(new(string)), + }, + { + Flag: "bool-flag", + Description: "A boolean flag", + Value: BoolOf(new(bool)), + Required: true, + }, + { + Flag: "file-path", + Description: "A file path", + Value: StringOf(new(string)), + }, + { + Flag: "enum-choice", + Description: "An enum choice", + Value: EnumOf(new(string), "option1", "option2", "option3"), + }, + }, + } + + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + server := NewMCPServer(cmd, stdin, stdout, stderr) + schema, err := server.generateJSONSchema(cmd) + + if err != nil { + t.Fatalf("Failed to generate schema: %v", err) + } + + var schemaObj map[string]interface{} + if err := json.Unmarshal(schema, &schemaObj); err != nil { + t.Fatalf("Generated schema is not valid JSON: %v", err) + } + + // Validate schema structure + if schemaObj["type"] != "object" { + t.Error("Schema should have type 'object'") + } + + properties, ok := schemaObj["properties"].(map[string]interface{}) + if !ok { + t.Fatal("Schema should have properties map") + } + + // Check if required fields are properly identified + required, ok := schemaObj["required"].([]interface{}) + if !ok { + t.Fatal("Schema should have required array") + } + + foundRequired := false + for _, r := range required { + if r == "bool-flag" { + foundRequired = true + break + } + } + if !foundRequired { + t.Error("Required flag not found in required list") + } + + // Check if properties have correct types + filePathProp, ok := properties["file-path"].(map[string]interface{}) + if !ok { + t.Fatal("file-path property not found or not an object") + } + if filePathProp["format"] != "path" { + t.Errorf("file-path should have format 'path', got %v", filePathProp["format"]) + } + + enumProp, ok := properties["enum-choice"].(map[string]interface{}) + if !ok { + t.Fatal("enum-choice property not found or not an object") + } + + enumValues, ok := enumProp["enum"].([]interface{}) + if !ok || len(enumValues) != 3 { + t.Errorf("enum-choice should have enum array with 3 values, got %v", enumProp["enum"]) + } +} + +func TestMCPServerRun(t *testing.T) { + // Create a simple command for testing + cmd := &Command{ + Use: "test", + Short: "Test command", + } + + toolCmd := &Command{ + Use: "tool-cmd", + Short: "Tool command", + Tool: "test-tool", + Handler: func(inv *Invocation) error { + fmt.Fprintln(inv.Stdout, "Tool executed!") + return nil + }, + } + + cmd.AddSubcommands(toolCmd) + + // Setup the server with buffers + input := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","clientInfo":{"name":"test-client","version":"1.0.0"},"capabilities":{}}} +{"jsonrpc":"2.0","id":2,"method":"notifications/initialized","params":{}} +{"jsonrpc":"2.0","id":3,"method":"tools/list","params":{}} +` + stdin := strings.NewReader(input) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + server := NewMCPServer(cmd, stdin, stdout, stderr) + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + // Run the server (it will stop when stdin is drained or context is cancelled) + err := server.Run(ctx) + if err != nil { + t.Fatalf("Server run failed: %v", err) + } + + // Check output + output := stdout.String() + + // Verify we got the expected responses + if !strings.Contains(output, `"protocolVersion"`) { + t.Error("Missing protocol version in initialize response") + } + + if !strings.Contains(output, `"tools":[`) { + t.Error("Missing tools list in response") + } + + if !strings.Contains(output, `"test-tool"`) { + t.Error("Tool name not found in response") + } +} + +func TestLenientParameterHandling(t *testing.T) { + // Create a simple command for testing + cmd := &Command{ + Use: "test", + Short: "Test command", + } + + toolCmd := &Command{ + Use: "tool-cmd", + Short: "Tool command", + Tool: "test-tool", + Handler: func(inv *Invocation) error { + fmt.Fprintln(inv.Stdout, "Tool executed!") + return nil + }, + } + + cmd.AddSubcommands(toolCmd) + + // Test the unmarshalParamsLenient function directly + testCases := []struct { + name string + params json.RawMessage + expectErr bool + }{ + { + name: "Missing params (nil)", + params: nil, + expectErr: false, + }, + { + name: "Empty params object", + params: json.RawMessage(`{}`), + expectErr: false, + }, + { + name: "Null params", + params: json.RawMessage(`null`), + expectErr: false, + }, + { + name: "Empty params array", + params: json.RawMessage(`[]`), + expectErr: false, + }, + { + name: "Partial params", + params: json.RawMessage(`{"protocolVersion":"2025-03-26"}`), + expectErr: false, + }, + { + name: "Invalid params format", + params: json.RawMessage(`"invalid"`), + expectErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if _, err := UnmarshalParamsLenient[InitializeParams](tc.params); err != nil { + if !tc.expectErr { + t.Errorf("Expected no error but got: %v", err) + } + return + } + + if tc.expectErr { + t.Error("Expected error but got none") + } + }) + } +} + +func TestPosArgsFromCmdUsage(t *testing.T) { + for _, tc := range []struct { + input string + expected map[string]any + expectedError string + }{ + { + input: "", + expectedError: "usage may not be empty", + }, + { + input: "command", + expected: map[string]any{}, + }, + { + input: "[flags]", + expected: map[string]any{}, + }, + { + input: "command [flags]", + expected: map[string]any{}, + }, + { + input: "command [flags] a [b] [] [e...]", + expected: map[string]any{ + "arg1__a": map[string]any{ + "description": "required argument", + "type": "string", + }, + "arg2__b": map[string]any{ + "description": "optional argument", + "type": "string", + }, + "arg3__c": map[string]any{ + "description": "required argument", + "type": "string", + }, + "arg4__d": map[string]any{ + "description": "optional argument", + "type": "string", + }, + "arg5__e": map[string]any{ + "description": "optional argument", + "type": "string", + }, + }, + }, + { + input: "command [flags] [c|d]", + expected: map[string]any{ + "arg1__a_b": map[string]any{ + "description": "required argument", + "enum": []string{"a", "b"}, + "type": "string", + }, + "arg2__c_d": map[string]any{ + "description": "optional argument", + "enum": []string{"c", "d"}, + "type": "string", + }, + }, + }, + { + input: "command [flags] ", + expectedError: "malformed usage", + }, + { + input: "command [flags] [c | d]", + expectedError: "malformed usage", + }, + { + input: "command [flags] {e f}", + expectedError: "malformed usage", + }, + } { + actual, err := PosArgsFromCmdUsage(tc.input) + if tc.expectedError == "" { + if err != nil { + t.Errorf("expected no error from %q, got %v", tc.input, err) + continue + } + if diff := cmp.Diff(tc.expected, actual); diff != "" { + t.Errorf("unexpected diff (-want +got):\n%s", diff) + continue + } + } else { + if err == nil { + t.Errorf("expected error containing '%s' from input %q, got no error", tc.expectedError, tc.input) + continue + } + if !strings.Contains(err.Error(), tc.expectedError) { + t.Errorf("expected error containing '%s' from input %q, got '%s'", tc.expectedError, tc.input, err.Error()) + } + if len(actual) != 0 { + t.Errorf("expected empty result on error, got %v", actual) + } + } + } +}