Skip to content

Commit ad23ad6

Browse files
authored
exp/api: Add accepted msg type validation to handler (#1750)
* exp: Add accepted msg type validation to handler Signed-off-by: Saswata Mukherjee <[email protected]> * Apply feedback Signed-off-by: Saswata Mukherjee <[email protected]> --------- Signed-off-by: Saswata Mukherjee <[email protected]>
1 parent 248c3f7 commit ad23ad6

File tree

3 files changed

+28
-10
lines changed

3 files changed

+28
-10
lines changed

exp/api/remote/remote_api.go

+17-4
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,9 @@ type writeStorage interface {
377377
}
378378

379379
type handler struct {
380-
store writeStorage
381-
opts handlerOpts
380+
store writeStorage
381+
acceptedMessageTypes MessageTypes
382+
opts handlerOpts
382383
}
383384

384385
type handlerOpts struct {
@@ -455,15 +456,20 @@ func SnappyDecompressorMiddleware(logger *slog.Logger) func(http.Handler) http.H
455456

456457
// NewHandler returns HTTP handler that receives Remote Write 2.0
457458
// protocol https://prometheus.io/docs/specs/remote_write_spec_2_0/.
458-
func NewHandler(store writeStorage, opts ...HandlerOption) http.Handler {
459+
func NewHandler(store writeStorage, acceptedMessageTypes MessageTypes, opts ...HandlerOption) http.Handler {
459460
o := handlerOpts{
460461
logger: slog.New(nopSlogHandler{}),
461462
middlewares: []func(http.Handler) http.Handler{SnappyDecompressorMiddleware(slog.New(nopSlogHandler{}))},
462463
}
463464
for _, opt := range opts {
464465
opt(&o)
465466
}
466-
h := &handler{opts: o, store: store}
467+
468+
h := &handler{
469+
opts: o,
470+
store: store,
471+
acceptedMessageTypes: acceptedMessageTypes,
472+
}
467473

468474
// Apply all middlewares in order
469475
var handler http.Handler = h
@@ -524,6 +530,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
524530
return
525531
}
526532

533+
if !h.acceptedMessageTypes.Contains(msgType) {
534+
err := fmt.Errorf("%v protobuf message is not accepted by this server; only accepts %v", msgType, h.acceptedMessageTypes.String())
535+
h.opts.logger.Error("Unaccepted message type", "msgType", msgType, "err", err)
536+
http.Error(w, err.Error(), http.StatusUnsupportedMediaType)
537+
return
538+
}
539+
527540
writeResponse, storeErr := h.store.Store(r.Context(), msgType, r)
528541

529542
// Set required X-Prometheus-Remote-Write-Written-* response headers, in all cases, alongwith any user-defined headers.

exp/api/remote/remote_api_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) {
147147
t.Run("success", func(t *testing.T) {
148148
tLogger := slog.Default()
149149
mStore := &mockStorage{}
150-
srv := httptest.NewServer(NewHandler(mStore, WithHandlerLogger(tLogger)))
150+
srv := httptest.NewServer(NewHandler(mStore, MessageTypes{WriteV2MessageType}, WithHandlerLogger(tLogger)))
151151
t.Cleanup(srv.Close)
152152

153153
client, err := NewAPI(srv.URL,
@@ -182,7 +182,7 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) {
182182
mockErr: errors.New("storage error"),
183183
mockCode: &mockCode,
184184
}
185-
srv := httptest.NewServer(NewHandler(mStore, WithHandlerLogger(tLogger)))
185+
srv := httptest.NewServer(NewHandler(mStore, MessageTypes{WriteV2MessageType}, WithHandlerLogger(tLogger)))
186186
t.Cleanup(srv.Close)
187187

188188
client, err := NewAPI(srv.URL,

exp/api/remote/remote_headers.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"errors"
1818
"fmt"
1919
"net/http"
20+
"slices"
2021
"strconv"
2122
"strings"
2223
)
@@ -59,24 +60,28 @@ func (n WriteMessageType) Validate() error {
5960
case WriteV1MessageType, WriteV2MessageType:
6061
return nil
6162
default:
62-
return fmt.Errorf("unknown type for remote write protobuf message %v, supported: %v", n, messageTypes{WriteV1MessageType, WriteV2MessageType}.String())
63+
return fmt.Errorf("unknown type for remote write protobuf message %v, supported: %v", n, MessageTypes{WriteV1MessageType, WriteV2MessageType}.String())
6364
}
6465
}
6566

66-
type messageTypes []WriteMessageType
67+
type MessageTypes []WriteMessageType
6768

68-
func (m messageTypes) Strings() []string {
69+
func (m MessageTypes) Strings() []string {
6970
ret := make([]string, 0, len(m))
7071
for _, typ := range m {
7172
ret = append(ret, string(typ))
7273
}
7374
return ret
7475
}
7576

76-
func (m messageTypes) String() string {
77+
func (m MessageTypes) String() string {
7778
return strings.Join(m.Strings(), ", ")
7879
}
7980

81+
func (m MessageTypes) Contains(mType WriteMessageType) bool {
82+
return slices.Contains(m, mType)
83+
}
84+
8085
var contentTypeHeaders = map[WriteMessageType]string{
8186
WriteV1MessageType: appProtoContentType, // Also application/x-protobuf;proto=prometheus.WriteRequest but simplified for compatibility with 1.x spec.
8287
WriteV2MessageType: appProtoContentType + ";proto=io.prometheus.write.v2.Request",

0 commit comments

Comments
 (0)