Skip to content

Commit dae1845

Browse files
committed
feat: update client options to use optional fields
1 parent 1c178ae commit dae1845

File tree

7 files changed

+87
-34
lines changed

7 files changed

+87
-34
lines changed

.golangci.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
severity:
2+
default: error
3+
rules:
4+
- linters:
5+
- errcheck
6+
- unused
7+
severity: info
8+
formatters:
9+
enable:
10+
- gci
11+
- gofmt
12+
- gofumpt
13+
- goimports
14+
version: "2"

Makefile

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,15 @@ buildLinuxX86:
1515

1616
.PHONY: buildImage
1717
buildImage:
18-
docker buildx build --platform=linux/amd64,linux/arm64 -t ghcr.io/tbxark/map-proxy:latest . --push --provenance=false
18+
docker buildx build --platform=linux/amd64,linux/arm64 -t ghcr.io/tbxark/map-proxy:latest . --push --provenance=false
19+
20+
.PHONY: lint
21+
lint:
22+
golangci-lint run
23+
24+
.PHONY: format
25+
format:
26+
golangci-lint fmt
27+
golangci-lint run --fix
28+
go fmt ./...
29+
go mod tidy

client.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"log"
8+
"time"
9+
710
"github.com/mark3labs/mcp-go/client"
811
"github.com/mark3labs/mcp-go/client/transport"
912
"github.com/mark3labs/mcp-go/mcp"
1013
"github.com/mark3labs/mcp-go/server"
11-
"log"
12-
"time"
1314
)
1415

1516
type Client struct {
@@ -114,11 +115,12 @@ func (c *Client) addToMCPServer(ctx context.Context, clientInfo mcp.Implementati
114115
func (c *Client) startPingTask(ctx context.Context) {
115116
ticker := time.NewTicker(30 * time.Second)
116117
defer ticker.Stop()
118+
PingLoop:
117119
for {
118120
select {
119121
case <-ctx.Done():
120122
log.Printf("<%s> Context done, stopping ping", c.name)
121-
break
123+
break PingLoop
122124
case <-ticker.C:
123125
_ = c.client.Ping(ctx)
124126
}
@@ -143,7 +145,7 @@ func (c *Client) addToolsToServer(ctx context.Context, mcpServer *server.MCPServ
143145
if tools.NextCursor == "" {
144146
break
145147
}
146-
toolsRequest.PaginatedRequest.Params.Cursor = tools.NextCursor
148+
toolsRequest.Params.Cursor = tools.NextCursor
147149
}
148150
return nil
149151
}
@@ -166,7 +168,7 @@ func (c *Client) addPromptsToServer(ctx context.Context, mcpServer *server.MCPSe
166168
if prompts.NextCursor == "" {
167169
break
168170
}
169-
promptsRequest.PaginatedRequest.Params.Cursor = prompts.NextCursor
171+
promptsRequest.Params.Cursor = prompts.NextCursor
170172
}
171173
return nil
172174
}
@@ -195,7 +197,7 @@ func (c *Client) addResourcesToServer(ctx context.Context, mcpServer *server.MCP
195197
if resources.NextCursor == "" {
196198
break
197199
}
198-
resourcesRequest.PaginatedRequest.Params.Cursor = resources.NextCursor
200+
resourcesRequest.Params.Cursor = resources.NextCursor
199201

200202
}
201203
return nil
@@ -225,7 +227,7 @@ func (c *Client) addResourceTemplatesToServer(ctx context.Context, mcpServer *se
225227
if resourceTemplates.NextCursor == "" {
226228
break
227229
}
228-
resourceTemplatesRequest.PaginatedRequest.Params.Cursor = resourceTemplates.NextCursor
230+
resourceTemplatesRequest.Params.Cursor = resourceTemplates.NextCursor
229231
}
230232
return nil
231233
}
@@ -249,7 +251,7 @@ func newMCPServer(name, version, baseURL string, clientConfig *MCPClientConfigV2
249251
server.WithRecovery(),
250252
}
251253

252-
if *clientConfig.Options.LogEnabled {
254+
if clientConfig.Options.LogEnabled.OrElse(false) {
253255
serverOpts = append(serverOpts, server.WithLogging())
254256
}
255257
mcpServer := server.NewMCPServer(

config.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package main
33
import (
44
"encoding/json"
55
"errors"
6-
"github.com/TBXark/confstore"
76
"time"
7+
8+
"github.com/TBXark/confstore"
9+
"github.com/TBXark/optional-go"
810
)
911

1012
type StdioMCPClientConfig struct {
@@ -81,9 +83,9 @@ func parseMCPClientConfigV1(conf *MCPClientConfigV1) (any, error) {
8183
// ---- V2 ----
8284

8385
type OptionsV2 struct {
84-
PanicIfInvalid *bool `json:"panicIfInvalid,omitempty"`
85-
LogEnabled *bool `json:"logEnabled,omitempty"`
86-
AuthTokens []string `json:"authTokens,omitempty"`
86+
PanicIfInvalid optional.Field[bool] `json:"panicIfInvalid,omitempty"`
87+
LogEnabled optional.Field[bool] `json:"logEnabled,omitempty"`
88+
AuthTokens []string `json:"authTokens,omitempty"`
8789
}
8890

8991
type MCPProxyConfigV2 struct {
@@ -178,11 +180,8 @@ func load(path string) (*Config, error) {
178180
if cErr != nil {
179181
continue
180182
}
181-
falseVal := false
182183
options := &OptionsV2{
183-
PanicIfInvalid: &falseVal,
184-
LogEnabled: &falseVal,
185-
AuthTokens: clientConfig.AuthTokens,
184+
AuthTokens: clientConfig.AuthTokens,
186185
}
187186
if conf.DeprecatedServerV1 != nil && len(conf.DeprecatedServerV1.GlobalAuthTokens) > 0 {
188187
options.AuthTokens = append(options.AuthTokens, conf.DeprecatedServerV1.GlobalAuthTokens...)
@@ -217,11 +216,7 @@ func load(path string) (*Config, error) {
217216
return nil, errors.New("mcpProxy is required")
218217
}
219218
if conf.McpProxy.Options == nil {
220-
falseVal := false
221-
conf.McpProxy.Options = &OptionsV2{
222-
PanicIfInvalid: &falseVal,
223-
LogEnabled: &falseVal,
224-
}
219+
conf.McpProxy.Options = &OptionsV2{}
225220
}
226221
for _, clientConfig := range conf.McpServers {
227222
if clientConfig.Options == nil {
@@ -230,10 +225,10 @@ func load(path string) (*Config, error) {
230225
if clientConfig.Options.AuthTokens == nil {
231226
clientConfig.Options.AuthTokens = conf.McpProxy.Options.AuthTokens
232227
}
233-
if clientConfig.Options.PanicIfInvalid == nil {
228+
if !clientConfig.Options.PanicIfInvalid.Present() {
234229
clientConfig.Options.PanicIfInvalid = conf.McpProxy.Options.PanicIfInvalid
235230
}
236-
if clientConfig.Options.LogEnabled == nil {
231+
if !clientConfig.Options.LogEnabled.Present() {
237232
clientConfig.Options.LogEnabled = conf.McpProxy.Options.LogEnabled
238233
}
239234
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ toolchain go1.23.7
66

77
require (
88
github.com/TBXark/confstore v0.0.3
9+
github.com/TBXark/optional-go v0.0.1
910
github.com/mark3labs/mcp-go v0.23.1
1011
golang.org/x/sync v0.13.0
1112
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
github.com/TBXark/confstore v0.0.2 h1:rmeho8xuUeJriJTC3q9iVuR7BJb4Fibnc3FiGKJNc9w=
2-
github.com/TBXark/confstore v0.0.2/go.mod h1:TOxM19Snt9wT02PJzyz66sgq7sWhr4AFzPEZIKvVnTE=
31
github.com/TBXark/confstore v0.0.3 h1:d+djx1k6lh6E9UZsPKXpigYZk0z1j5KoMuhwgb2LLMg=
42
github.com/TBXark/confstore v0.0.3/go.mod h1:TOxM19Snt9wT02PJzyz66sgq7sWhr4AFzPEZIKvVnTE=
3+
github.com/TBXark/optional-go v0.0.1 h1:ZIeoYfA7UWcpx+Otxdc0f0tvfSDkJuJVYmjnLfr2P8I=
4+
github.com/TBXark/optional-go v0.0.1/go.mod h1:skpoGkocQNq/IRct1T2rgwSrXEy1nUY+Sz28r68t4yE=
55
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
66
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
77
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=

http.go

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"github.com/mark3labs/mcp-go/mcp"
8-
"golang.org/x/sync/errgroup"
97
"log"
108
"net/http"
119
"os"
1210
"os/signal"
1311
"strings"
1412
"syscall"
1513
"time"
14+
15+
"github.com/mark3labs/mcp-go/mcp"
16+
"golang.org/x/sync/errgroup"
1617
)
1718

1819
type MiddlewareFunc func(http.Handler) http.Handler
@@ -33,9 +34,7 @@ func newAuthMiddleware(tokens []string) MiddlewareFunc {
3334
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3435
if len(tokens) != 0 {
3536
token := r.Header.Get("Authorization")
36-
if strings.HasPrefix(token, "Bearer ") {
37-
token = strings.TrimPrefix(token, "Bearer ")
38-
}
37+
token = strings.TrimSpace(strings.TrimPrefix(token, "Bearer "))
3938
if token == "" {
4039
http.Error(w, "Unauthorized", http.StatusUnauthorized)
4140
return
@@ -50,8 +49,30 @@ func newAuthMiddleware(tokens []string) MiddlewareFunc {
5049
}
5150
}
5251

53-
func startHTTPServer(config *Config) {
52+
func loggerMiddleware(prefix string) MiddlewareFunc {
53+
return func(next http.Handler) http.Handler {
54+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
55+
log.Printf("<%s> Request [%s] %s", prefix, r.Method, r.URL.Path)
56+
next.ServeHTTP(w, r)
57+
})
58+
}
59+
}
5460

61+
func recoverMiddleware(prefix string) MiddlewareFunc {
62+
return func(next http.Handler) http.Handler {
63+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
64+
defer func() {
65+
if err := recover(); err != nil {
66+
log.Printf("<%s> Recovered from panic: %v", prefix, err)
67+
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
68+
}
69+
}()
70+
next.ServeHTTP(w, r)
71+
})
72+
}
73+
}
74+
75+
func startHTTPServer(config *Config) {
5576
ctx, cancel := context.WithCancel(context.Background())
5677
defer cancel()
5778

@@ -77,13 +98,22 @@ func startHTTPServer(config *Config) {
7798
addErr := mcpClient.addToMCPServer(ctx, info, server.mcpServer)
7899
if addErr != nil {
79100
log.Printf("<%s> Failed to add client to server: %v", name, addErr)
80-
if *clientConfig.Options.PanicIfInvalid {
101+
if clientConfig.Options.PanicIfInvalid.OrElse(false) {
81102
return addErr
82103
}
83104
return nil
84105
}
85106
log.Printf("<%s> Connected", name)
86-
httpMux.Handle(fmt.Sprintf("/%s/", name), chainMiddleware(server.sseServer, newAuthMiddleware(server.tokens)))
107+
108+
middlewares := make([]MiddlewareFunc, 0)
109+
middlewares = append(middlewares, recoverMiddleware(name))
110+
if clientConfig.Options.LogEnabled.OrElse(false) {
111+
middlewares = append(middlewares, loggerMiddleware(name))
112+
}
113+
if len(clientConfig.Options.AuthTokens) > 0 {
114+
middlewares = append(middlewares, newAuthMiddleware(clientConfig.Options.AuthTokens))
115+
}
116+
httpMux.Handle(fmt.Sprintf("/%s/", name), chainMiddleware(server.sseServer, middlewares...))
87117
httpServer.RegisterOnShutdown(func() {
88118
log.Printf("<%s> Shutting down", name)
89119
_ = mcpClient.Close()

0 commit comments

Comments
 (0)