Skip to content

Commit 39cba1c

Browse files
tech-guru42pflynn-virtrudmihalcik-virtrubiscoe916
committed
feat(core): adds authn interceptor to grpc server (#296)
* feat: verify and validate access tokens on service calls * add autnInterceptor * save * save progress * remove testIDP var * cleanup * comments * comment * unit tests for access token verification and validation * updated configuration docs * registered authn check as handler in mux chain * rename authN config field and remove left over log line * move authn to internal * fix authn test * only set issuer in platform welknown config * fix loading authn with handler * fix grpccurl step * fix healthcheck grpccurl call * didn't save * disable auth for service extension test * need to set mux to handler on server start * try nohub * try just go start to see errors * pause on starting opentdf * disable auth in example config * Update internal/auth/authn.go Co-authored-by: Paul Flynn <[email protected]> * Update internal/server/server.go Co-authored-by: Dave Mihalcik <[email protected]> * fix lint errors --------- Co-authored-by: Paul Flynn <[email protected]> Co-authored-by: Dave Mihalcik <[email protected]> Co-authored-by: Tyler Biscoe <[email protected]>
1 parent 5d90ffd commit 39cba1c

File tree

13 files changed

+699
-36
lines changed

13 files changed

+699
-36
lines changed

.github/workflows/checks.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ jobs:
114114
wait-for: 90s
115115
- run: go install github.com/fullstorydev/grpcurl/cmd/[email protected]
116116
- run: grpcurl -plaintext localhost:9000 list
117-
- run: grpcurl -plaintext localhost:9000 list policy.attributes.AttributesService
117+
- run: grpcurl -plaintext localhost:9000 grpc.health.v1.Health.Check
118118

119119
image:
120120
name: image build

docs/configuration.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ The server configuration is used to define how the application runs its server.
4242
- `enabled`: Enable tls. `(default: false)`
4343
- `cert`: The path to the tls certificate.
4444
- `key`: The path to the tls key.
45+
- `auth`: The configuration for your trusted IDP.
46+
- `enabled`: Enable authentication. `(default: true)`
47+
- `audience`: The audience for the IDP.
48+
- `issuer`: The issuer for the IDP.
49+
- `clients`: A list of client id's that are allowed
4550

4651
Example:
4752

@@ -56,6 +61,13 @@ server:
5661
enabled: true
5762
cert: /path/to/cert
5863
key: /path/to/key
64+
auth:
65+
enabled: true
66+
audience: https://example.com
67+
issuer: https://example.com
68+
clients:
69+
- client_id
70+
- client_id2
5971
```
6072

6173
## Database Configuration

example-opentdf.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ services:
3838
- "msExchMailboxGuid"
3939
- "msExchMailboxSecurityDescriptor"
4040
server:
41+
auth:
42+
enabled: false
43+
audience: "opentdf"
44+
issuer: http://localhost:8888/auth/realms/opentdf
45+
clients:
46+
- "opentdf"
4147
grpc:
4248
port: 9000
4349
reflectionEnabled: true # Default is false

internal/auth/authn.go

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log/slog"
7+
"net/http"
8+
"slices"
9+
"strings"
10+
"time"
11+
12+
"github.com/lestrrat-go/jwx/v2/jwk"
13+
"github.com/lestrrat-go/jwx/v2/jwt"
14+
"google.golang.org/grpc"
15+
"google.golang.org/grpc/codes"
16+
"google.golang.org/grpc/metadata"
17+
"google.golang.org/grpc/status"
18+
)
19+
20+
var (
21+
// Set of allowed gRPC endpoints that do not require authentication
22+
allowedGRPCEndpoints = [...]string{
23+
"/grpc.health.v1.Health/Check",
24+
"/wellknownconfiguration.WellKnownService/GetWellKnownConfiguration",
25+
}
26+
// Set of allowed HTTP endpoints that do not require authentication
27+
allowedHTTPEndpoints = [...]string{
28+
"/healthz",
29+
"/.well-known/opentdf-configuration",
30+
}
31+
)
32+
33+
// Authentication holds a jwks cache and information about the openid configuration
34+
type authentication struct {
35+
// cache holds the jwks cache
36+
cache *jwk.Cache
37+
// openidConfigurations holds the openid configuration for each issuer
38+
oidcConfigurations map[string]AuthNConfig
39+
}
40+
41+
// Creates new authN which is used to verify tokens for a set of given issuers
42+
func NewAuthenticator(cfg AuthNConfig) (*authentication, error) {
43+
a := &authentication{}
44+
a.oidcConfigurations = make(map[string]AuthNConfig)
45+
46+
ctx := context.Background()
47+
48+
a.cache = jwk.NewCache(ctx)
49+
50+
// Build new cache
51+
// Discover OIDC Configuration
52+
oidcConfig, err := DiscoverOIDCConfiguration(ctx, cfg.Issuer)
53+
if err != nil {
54+
return nil, err
55+
}
56+
57+
cfg.OIDCConfiguration = *oidcConfig
58+
59+
// Register the jwks_uri with the cache
60+
if err := a.cache.Register(cfg.JwksURI, jwk.WithMinRefreshInterval(15*time.Minute)); err != nil {
61+
return nil, err
62+
}
63+
64+
// Need to refresh the cache to verify jwks is available
65+
_, err = a.cache.Refresh(ctx, cfg.JwksURI)
66+
if err != nil {
67+
return nil, err
68+
}
69+
70+
a.oidcConfigurations[cfg.Issuer] = cfg
71+
72+
return a, nil
73+
}
74+
75+
// verifyTokenHandler is a http handler that verifies the token
76+
func (a authentication) VerifyTokenHandler(handler http.Handler) http.Handler {
77+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78+
if slices.Contains(allowedHTTPEndpoints[:], r.URL.Path) {
79+
handler.ServeHTTP(w, r)
80+
return
81+
}
82+
// Verify the token
83+
header := r.Header["Authorization"]
84+
if len(header) < 1 {
85+
http.Error(w, "missing authorization header", http.StatusUnauthorized)
86+
return
87+
}
88+
err := checkToken(r.Context(), header, a)
89+
if err != nil {
90+
slog.WarnContext(r.Context(), "failed to validate token", slog.String("error", err.Error()))
91+
http.Error(w, "unauthenticated", http.StatusUnauthorized)
92+
return
93+
}
94+
95+
handler.ServeHTTP(w, r)
96+
})
97+
}
98+
99+
// verifyTokenInterceptor is a grpc interceptor that verifies the token in the metadata
100+
func (a authentication) VerifyTokenInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
101+
// Allow health checks to pass through
102+
if slices.Contains(allowedGRPCEndpoints[:], info.FullMethod) {
103+
return handler(ctx, req)
104+
}
105+
106+
// Get the metadata from the context
107+
// The keys within metadata.MD are normalized to lowercase.
108+
// See: https://godoc.org/google.golang.org/grpc/metadata#New
109+
md, ok := metadata.FromIncomingContext(ctx)
110+
if !ok {
111+
return nil, status.Error(codes.Unauthenticated, "missing metadata")
112+
}
113+
114+
// Verify the token
115+
header := md["authorization"]
116+
if len(header) < 1 {
117+
return nil, status.Error(codes.Unauthenticated, "missing authorization header")
118+
}
119+
120+
err := checkToken(ctx, header, a)
121+
if err != nil {
122+
slog.Warn("failed to validate token", slog.String("error", err.Error()))
123+
return nil, status.Errorf(codes.Unauthenticated, "unauthenticated")
124+
}
125+
126+
return handler(ctx, req)
127+
}
128+
129+
// checkToken is a helper function to verify the token.
130+
func checkToken(ctx context.Context, authHeader []string, auth authentication) error {
131+
var (
132+
tokenRaw string
133+
tokenType string
134+
)
135+
136+
// If we don't get a DPoP/Bearer token type, we can't proceed
137+
switch {
138+
case strings.HasPrefix(authHeader[0], "DPoP "):
139+
tokenType = "DPoP"
140+
tokenRaw = strings.TrimPrefix(authHeader[0], "DPoP ")
141+
case strings.HasPrefix(authHeader[0], "Bearer "):
142+
tokenType = "Bearer"
143+
tokenRaw = strings.TrimPrefix(authHeader[0], "Bearer ")
144+
default:
145+
return fmt.Errorf("not of type bearer or dpop")
146+
}
147+
148+
// Future work is to validate DPoP proof if token type is DPoP
149+
//nolint:staticcheck
150+
if tokenType == "DPoP" {
151+
// Implement in the future here or as separate interceptor
152+
}
153+
154+
// We have to get iss from the token first to verify the signature
155+
unverifiedToken, err := jwt.Parse([]byte(tokenRaw), jwt.WithVerify(false))
156+
if err != nil {
157+
return err
158+
}
159+
160+
// Get issuer from unverified token
161+
issuer, exists := unverifiedToken.Get("iss")
162+
if !exists {
163+
return fmt.Errorf("missing issuer")
164+
}
165+
166+
// Get the openid configuration for the issuer
167+
// Because we get an interface we need to cast it to a string
168+
// and jwx expects it as a string so we should never hit this error if the token is valid
169+
issuerStr, ok := issuer.(string)
170+
if !ok {
171+
return fmt.Errorf("invalid issuer")
172+
}
173+
oidc, exists := auth.oidcConfigurations[issuerStr]
174+
if !exists {
175+
return fmt.Errorf("invalid issuer")
176+
}
177+
178+
// Get key set from cache that matches the jwks_uri
179+
keySet, err := auth.cache.Get(ctx, oidc.JwksURI)
180+
if err != nil {
181+
return fmt.Errorf("failed to get jwks from cache")
182+
}
183+
184+
// Now we verify the token signature
185+
_, err = jwt.Parse([]byte(tokenRaw),
186+
jwt.WithKeySet(keySet),
187+
jwt.WithValidate(true),
188+
jwt.WithIssuer(issuerStr),
189+
jwt.WithAudience(oidc.Audience),
190+
jwt.WithValidator(jwt.ValidatorFunc(auth.claimsValidator)),
191+
)
192+
if err != nil {
193+
return err
194+
}
195+
196+
return nil
197+
}
198+
199+
// claimsValidator is a custom validator to check extra claims in the token.
200+
// right now it only checks for client_id
201+
func (a authentication) claimsValidator(ctx context.Context, token jwt.Token) jwt.ValidationError {
202+
var (
203+
clientID string
204+
)
205+
206+
// Need to check for cid and client_id as this claim seems to be different between idp's
207+
cidClaim, cidExists := token.Get("cid")
208+
clientIDClaim, clientIDExists := token.Get("client_id")
209+
210+
// Check to see if we have a client id claim
211+
switch {
212+
case cidExists:
213+
if cid, ok := cidClaim.(string); ok {
214+
clientID = cid
215+
break
216+
}
217+
case clientIDExists:
218+
if cid, ok := clientIDClaim.(string); ok {
219+
clientID = cid
220+
break
221+
}
222+
default:
223+
return jwt.NewValidationError(fmt.Errorf("client id required"))
224+
}
225+
226+
// Check if the client id is allowed in list of clients
227+
foundClientID := false
228+
for _, c := range a.oidcConfigurations[token.Issuer()].Clients {
229+
if c == clientID {
230+
foundClientID = true
231+
break
232+
}
233+
}
234+
if !foundClientID {
235+
return jwt.NewValidationError(fmt.Errorf("invalid client id"))
236+
}
237+
238+
return nil
239+
}

0 commit comments

Comments
 (0)