chore: migrate to gitea
Some checks failed
golangci-lint / lint (push) Failing after 21s
Test / test (push) Failing after 2m17s

This commit is contained in:
2026-01-27 01:40:31 +01:00
parent a9bca767a9
commit 1a27ed5274
3163 changed files with 1216358 additions and 1529 deletions

View File

@@ -0,0 +1 @@
**This directory has the implementation of the S2Av2's gRPC-Go client libraries**

View File

@@ -0,0 +1,122 @@
/*
*
* Copyright 2022 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package certverifier offloads verifications to S2Av2.
package certverifier
import (
"crypto/x509"
"fmt"
"github.com/google/s2a-go/stream"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
)
// VerifyClientCertificateChain builds a SessionReq, sends it to S2Av2 and
// receives a SessionResp.
func VerifyClientCertificateChain(verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, s2AStream stream.S2AStream) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// Offload verification to S2Av2.
if grpclog.V(1) {
grpclog.Infof("Sending request to S2Av2 for client peer cert chain validation.")
}
if err := s2AStream.Send(&s2av2pb.SessionReq{
ReqOneof: &s2av2pb.SessionReq_ValidatePeerCertificateChainReq{
ValidatePeerCertificateChainReq: &s2av2pb.ValidatePeerCertificateChainReq{
Mode: verificationMode,
PeerOneof: &s2av2pb.ValidatePeerCertificateChainReq_ClientPeer_{
ClientPeer: &s2av2pb.ValidatePeerCertificateChainReq_ClientPeer{
CertificateChain: rawCerts,
},
},
},
},
}); err != nil {
grpclog.Infof("Failed to send request to S2Av2 for client peer cert chain validation.")
return err
}
// Get the response from S2Av2.
resp, err := s2AStream.Recv()
if err != nil {
grpclog.Infof("Failed to receive client peer cert chain validation response from S2Av2.")
return err
}
// Parse the response.
if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
return fmt.Errorf("failed to offload client cert verification to S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
}
if resp.GetValidatePeerCertificateChainResp().ValidationResult != s2av2pb.ValidatePeerCertificateChainResp_SUCCESS {
return fmt.Errorf("client cert verification failed: %v", resp.GetValidatePeerCertificateChainResp().ValidationDetails)
}
return nil
}
}
// VerifyServerCertificateChain builds a SessionReq, sends it to S2Av2 and
// receives a SessionResp.
func VerifyServerCertificateChain(hostname string, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, s2AStream stream.S2AStream, serverAuthorizationPolicy []byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
// Offload verification to S2Av2.
if grpclog.V(1) {
grpclog.Infof("Sending request to S2Av2 for server peer cert chain validation.")
}
if err := s2AStream.Send(&s2av2pb.SessionReq{
ReqOneof: &s2av2pb.SessionReq_ValidatePeerCertificateChainReq{
ValidatePeerCertificateChainReq: &s2av2pb.ValidatePeerCertificateChainReq{
Mode: verificationMode,
PeerOneof: &s2av2pb.ValidatePeerCertificateChainReq_ServerPeer_{
ServerPeer: &s2av2pb.ValidatePeerCertificateChainReq_ServerPeer{
CertificateChain: rawCerts,
ServerHostname: hostname,
SerializedUnrestrictedClientPolicy: serverAuthorizationPolicy,
},
},
},
},
}); err != nil {
grpclog.Infof("Failed to send request to S2Av2 for server peer cert chain validation.")
return err
}
// Get the response from S2Av2.
resp, err := s2AStream.Recv()
if err != nil {
grpclog.Infof("Failed to receive server peer cert chain validation response from S2Av2.")
return err
}
// Parse the response.
if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
return fmt.Errorf("failed to offload server cert verification to S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
}
if resp.GetValidatePeerCertificateChainResp().ValidationResult != s2av2pb.ValidatePeerCertificateChainResp_SUCCESS {
return fmt.Errorf("server cert verification failed: %v", resp.GetValidatePeerCertificateChainResp().ValidationDetails)
}
return nil
}
}

View File

@@ -0,0 +1,186 @@
/*
*
* Copyright 2022 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package remotesigner offloads private key operations to S2Av2.
package remotesigner
import (
"crypto"
"crypto/rsa"
"crypto/x509"
"fmt"
"io"
"github.com/google/s2a-go/stream"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
)
// remoteSigner implementes the crypto.Signer interface.
type remoteSigner struct {
leafCert *x509.Certificate
s2AStream stream.S2AStream
}
// New returns an instance of RemoteSigner, an implementation of the
// crypto.Signer interface.
func New(leafCert *x509.Certificate, s2AStream stream.S2AStream) crypto.Signer {
return &remoteSigner{leafCert, s2AStream}
}
func (s *remoteSigner) Public() crypto.PublicKey {
return s.leafCert.PublicKey
}
func (s *remoteSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
signatureAlgorithm, err := getSignatureAlgorithm(opts, s.leafCert)
if err != nil {
return nil, err
}
req, err := getSignReq(signatureAlgorithm, digest)
if err != nil {
return nil, err
}
if grpclog.V(1) {
grpclog.Infof("Sending request to S2Av2 for signing operation.")
}
if err := s.s2AStream.Send(&s2av2pb.SessionReq{
ReqOneof: &s2av2pb.SessionReq_OffloadPrivateKeyOperationReq{
OffloadPrivateKeyOperationReq: req,
},
}); err != nil {
grpclog.Infof("Failed to send request to S2Av2 for signing operation.")
return nil, err
}
resp, err := s.s2AStream.Recv()
if err != nil {
grpclog.Infof("Failed to receive signing operation response from S2Av2.")
return nil, err
}
if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
return nil, fmt.Errorf("failed to offload signing with private key to S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
}
return resp.GetOffloadPrivateKeyOperationResp().GetOutBytes(), nil
}
// getCert returns the leafCert field in s.
func (s *remoteSigner) getCert() *x509.Certificate {
return s.leafCert
}
// getStream returns the s2AStream field in s.
func (s *remoteSigner) getStream() stream.S2AStream {
return s.s2AStream
}
func getSignReq(signatureAlgorithm s2av2pb.SignatureAlgorithm, digest []byte) (*s2av2pb.OffloadPrivateKeyOperationReq, error) {
if (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PKCS1_SHA256) || (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256) || (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256) {
return &s2av2pb.OffloadPrivateKeyOperationReq{
Operation: s2av2pb.OffloadPrivateKeyOperationReq_SIGN,
SignatureAlgorithm: signatureAlgorithm,
InBytes: &s2av2pb.OffloadPrivateKeyOperationReq_Sha256Digest{
Sha256Digest: digest,
},
}, nil
} else if (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PKCS1_SHA384) || (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384) || (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384) {
return &s2av2pb.OffloadPrivateKeyOperationReq{
Operation: s2av2pb.OffloadPrivateKeyOperationReq_SIGN,
SignatureAlgorithm: signatureAlgorithm,
InBytes: &s2av2pb.OffloadPrivateKeyOperationReq_Sha384Digest{
Sha384Digest: digest,
},
}, nil
} else if (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PKCS1_SHA512) || (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512) || (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512) || (signatureAlgorithm == s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_ED25519) {
return &s2av2pb.OffloadPrivateKeyOperationReq{
Operation: s2av2pb.OffloadPrivateKeyOperationReq_SIGN,
SignatureAlgorithm: signatureAlgorithm,
InBytes: &s2av2pb.OffloadPrivateKeyOperationReq_Sha512Digest{
Sha512Digest: digest,
},
}, nil
} else {
return nil, fmt.Errorf("unknown signature algorithm: %v", signatureAlgorithm)
}
}
// getSignatureAlgorithm returns the signature algorithm that S2A must use when
// performing a signing operation that has been offloaded by an application
// using the crypto/tls libraries.
func getSignatureAlgorithm(opts crypto.SignerOpts, leafCert *x509.Certificate) (s2av2pb.SignatureAlgorithm, error) {
if opts == nil || leafCert == nil {
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_UNSPECIFIED, fmt.Errorf("unknown signature algorithm")
}
switch leafCert.PublicKeyAlgorithm {
case x509.RSA:
if rsaPSSOpts, ok := opts.(*rsa.PSSOptions); ok {
return rsaPSSAlgorithm(rsaPSSOpts)
}
return rsaPPKCS1Algorithm(opts)
case x509.ECDSA:
return ecdsaAlgorithm(opts)
case x509.Ed25519:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_ED25519, nil
default:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_UNSPECIFIED, fmt.Errorf("unknown signature algorithm: %q", leafCert.PublicKeyAlgorithm)
}
}
func rsaPSSAlgorithm(opts *rsa.PSSOptions) (s2av2pb.SignatureAlgorithm, error) {
switch opts.HashFunc() {
case crypto.SHA256:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256, nil
case crypto.SHA384:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384, nil
case crypto.SHA512:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512, nil
default:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_UNSPECIFIED, fmt.Errorf("unknown signature algorithm")
}
}
func rsaPPKCS1Algorithm(opts crypto.SignerOpts) (s2av2pb.SignatureAlgorithm, error) {
switch opts.HashFunc() {
case crypto.SHA256:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PKCS1_SHA256, nil
case crypto.SHA384:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PKCS1_SHA384, nil
case crypto.SHA512:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_RSA_PKCS1_SHA512, nil
default:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_UNSPECIFIED, fmt.Errorf("unknown signature algorithm")
}
}
func ecdsaAlgorithm(opts crypto.SignerOpts) (s2av2pb.SignatureAlgorithm, error) {
switch opts.HashFunc() {
case crypto.SHA256:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256, nil
case crypto.SHA384:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384, nil
case crypto.SHA512:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512, nil
default:
return s2av2pb.SignatureAlgorithm_S2A_SSL_SIGN_UNSPECIFIED, fmt.Errorf("unknown signature algorithm")
}
}

380
vendor/github.com/google/s2a-go/internal/v2/s2av2.go generated vendored Normal file
View File

@@ -0,0 +1,380 @@
/*
*
* Copyright 2022 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package v2 provides the S2Av2 transport credentials used by a gRPC
// application.
package v2
import (
"context"
"crypto/tls"
"errors"
"net"
"os"
"time"
"github.com/google/s2a-go/fallback"
"github.com/google/s2a-go/internal/handshaker/service"
"github.com/google/s2a-go/internal/tokenmanager"
"github.com/google/s2a-go/internal/v2/tlsconfigstore"
"github.com/google/s2a-go/retry"
"github.com/google/s2a-go/stream"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/proto"
commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
)
const (
s2aSecurityProtocol = "tls"
defaultS2ATimeout = 6 * time.Second
)
// An environment variable, which sets the timeout enforced on the connection to the S2A service for handshake.
const s2aTimeoutEnv = "S2A_TIMEOUT"
type s2av2TransportCreds struct {
info *credentials.ProtocolInfo
isClient bool
serverName string
s2av2Address string
transportCreds credentials.TransportCredentials
tokenManager *tokenmanager.AccessTokenManager
// localIdentity should only be used by the client.
localIdentity *commonpb.Identity
// localIdentities should only be used by the server.
localIdentities []*commonpb.Identity
verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
fallbackClientHandshake fallback.ClientHandshake
getS2AStream stream.GetS2AStream
serverAuthorizationPolicy []byte
}
// NewClientCreds returns a client-side transport credentials object that uses
// the S2Av2 to establish a secure connection with a server.
func NewClientCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentity *commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream stream.GetS2AStream, serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
// Create an AccessTokenManager instance to use to authenticate to S2Av2.
accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
creds := &s2av2TransportCreds{
info: &credentials.ProtocolInfo{
SecurityProtocol: s2aSecurityProtocol,
},
isClient: true,
serverName: "",
s2av2Address: s2av2Address,
transportCreds: transportCreds,
localIdentity: localIdentity,
verificationMode: verificationMode,
fallbackClientHandshake: fallbackClientHandshakeFunc,
getS2AStream: getS2AStream,
serverAuthorizationPolicy: serverAuthorizationPolicy,
}
if err != nil {
creds.tokenManager = nil
} else {
creds.tokenManager = &accessTokenManager
}
if grpclog.V(1) {
grpclog.Info("Created client S2Av2 transport credentials.")
}
return creds, nil
}
// NewServerCreds returns a server-side transport credentials object that uses
// the S2Av2 to establish a secure connection with a client.
func NewServerCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream stream.GetS2AStream) (credentials.TransportCredentials, error) {
// Create an AccessTokenManager instance to use to authenticate to S2Av2.
accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
creds := &s2av2TransportCreds{
info: &credentials.ProtocolInfo{
SecurityProtocol: s2aSecurityProtocol,
},
isClient: false,
s2av2Address: s2av2Address,
transportCreds: transportCreds,
localIdentities: localIdentities,
verificationMode: verificationMode,
getS2AStream: getS2AStream,
}
if err != nil {
creds.tokenManager = nil
} else {
creds.tokenManager = &accessTokenManager
}
if grpclog.V(1) {
grpclog.Info("Created server S2Av2 transport credentials.")
}
return creds, nil
}
// ClientHandshake performs a client-side mTLS handshake using the S2Av2.
func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if !c.isClient {
return nil, nil, errors.New("client handshake called using server transport credentials")
}
// Remove the port from serverAuthority.
serverName := removeServerNamePort(serverAuthority)
timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
defer cancel()
var s2AStream stream.S2AStream
var err error
retry.Run(timeoutCtx,
func() error {
s2AStream, err = createStream(timeoutCtx, c.s2av2Address, c.transportCreds, c.getS2AStream)
return err
})
if err != nil {
grpclog.Infof("Failed to connect to S2Av2: %v", err)
if c.fallbackClientHandshake != nil {
return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
}
return nil, nil, err
}
defer s2AStream.CloseSend()
if grpclog.V(1) {
grpclog.Infof("Connected to S2Av2.")
}
var config *tls.Config
var tokenManager tokenmanager.AccessTokenManager
if c.tokenManager == nil {
tokenManager = nil
} else {
tokenManager = *c.tokenManager
}
sn := serverName
if c.serverName != "" {
sn = c.serverName
}
retry.Run(timeoutCtx,
func() error {
config, err = tlsconfigstore.GetTLSConfigurationForClient(sn, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
return err
})
if err != nil {
grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
if c.fallbackClientHandshake != nil {
return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
}
return nil, nil, err
}
if grpclog.V(1) {
grpclog.Infof("Got client TLS config from S2Av2.")
}
creds := credentials.NewTLS(config)
conn, authInfo, err := creds.ClientHandshake(timeoutCtx, serverName, rawConn)
if err != nil {
grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
if c.fallbackClientHandshake != nil {
return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
}
return nil, nil, err
}
grpclog.Infof("client-side handshake is done using S2Av2 to: %s", serverName)
return conn, authInfo, err
}
// ServerHandshake performs a server-side mTLS handshake using the S2Av2.
func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if c.isClient {
return nil, nil, errors.New("server handshake called using client transport credentials")
}
ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
defer cancel()
var s2AStream stream.S2AStream
var err error
retry.Run(ctx,
func() error {
s2AStream, err = createStream(ctx, c.s2av2Address, c.transportCreds, c.getS2AStream)
return err
})
if err != nil {
grpclog.Infof("Failed to connect to S2Av2: %v", err)
return nil, nil, err
}
defer s2AStream.CloseSend()
if grpclog.V(1) {
grpclog.Infof("Connected to S2Av2.")
}
var tokenManager tokenmanager.AccessTokenManager
if c.tokenManager == nil {
tokenManager = nil
} else {
tokenManager = *c.tokenManager
}
var config *tls.Config
retry.Run(ctx,
func() error {
config, err = tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
return err
})
if err != nil {
grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
return nil, nil, err
}
if grpclog.V(1) {
grpclog.Infof("Got server TLS config from S2Av2.")
}
creds := credentials.NewTLS(config)
conn, authInfo, err := creds.ServerHandshake(rawConn)
if err != nil {
grpclog.Infof("Failed to do server handshake using S2Av2: %v", err)
return nil, nil, err
}
return conn, authInfo, err
}
// Info returns protocol info of s2av2TransportCreds.
func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
return *c.info
}
// Clone makes a deep copy of s2av2TransportCreds.
func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
info := *c.info
serverName := c.serverName
fallbackClientHandshake := c.fallbackClientHandshake
s2av2Address := c.s2av2Address
var tokenManager tokenmanager.AccessTokenManager
if c.tokenManager == nil {
tokenManager = nil
} else {
tokenManager = *c.tokenManager
}
verificationMode := c.verificationMode
var localIdentity *commonpb.Identity
if c.localIdentity != nil {
localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
}
var localIdentities []*commonpb.Identity
if c.localIdentities != nil {
localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
for i, localIdentity := range c.localIdentities {
localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
}
}
creds := &s2av2TransportCreds{
info: &info,
isClient: c.isClient,
serverName: serverName,
fallbackClientHandshake: fallbackClientHandshake,
s2av2Address: s2av2Address,
localIdentity: localIdentity,
localIdentities: localIdentities,
verificationMode: verificationMode,
}
if c.tokenManager == nil {
creds.tokenManager = nil
} else {
creds.tokenManager = &tokenManager
}
return creds
}
// NewClientTLSConfig returns a tls.Config instance that uses S2Av2 to establish a TLS connection as
// a client. The tls.Config MUST only be used to establish a single TLS connection.
func NewClientTLSConfig(
ctx context.Context,
s2av2Address string,
transportCreds credentials.TransportCredentials,
tokenManager tokenmanager.AccessTokenManager,
verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
serverName string,
serverAuthorizationPolicy []byte,
getStream stream.GetS2AStream) (*tls.Config, error) {
s2AStream, err := createStream(ctx, s2av2Address, transportCreds, getStream)
if err != nil {
grpclog.Infof("Failed to connect to S2Av2: %v", err)
return nil, err
}
return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
}
// OverrideServerName sets the ServerName in the s2av2TransportCreds protocol
// info. The ServerName MUST be a hostname.
func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
serverName := removeServerNamePort(serverNameOverride)
c.info.ServerName = serverName
c.serverName = serverName
return nil
}
// Remove the trailing port from server name.
func removeServerNamePort(serverName string) string {
name, _, err := net.SplitHostPort(serverName)
if err != nil {
name = serverName
}
return name
}
type s2AGrpcStream struct {
stream s2av2pb.S2AService_SetUpSessionClient
}
func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
return x.stream.Send(m)
}
func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
return x.stream.Recv()
}
func (x s2AGrpcStream) CloseSend() error {
return x.stream.CloseSend()
}
func createStream(ctx context.Context, s2av2Address string, transportCreds credentials.TransportCredentials, getS2AStream stream.GetS2AStream) (stream.S2AStream, error) {
if getS2AStream != nil {
return getS2AStream(ctx, s2av2Address)
}
// TODO(rmehta19): Consider whether to close the connection to S2Av2.
conn, err := service.Dial(ctx, s2av2Address, transportCreds)
if err != nil {
return nil, err
}
client := s2av2pb.NewS2AServiceClient(conn)
gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
if err != nil {
return nil, err
}
return &s2AGrpcStream{
stream: gRPCStream,
}, nil
}
// GetS2ATimeout returns the timeout enforced on the connection to the S2A service for handshake.
func GetS2ATimeout() time.Duration {
timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
if err != nil {
return defaultS2ATimeout
}
return timeout
}

View File

@@ -0,0 +1,403 @@
/*
*
* Copyright 2022 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package tlsconfigstore offloads operations to S2Av2.
package tlsconfigstore
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"github.com/google/s2a-go/internal/tokenmanager"
"github.com/google/s2a-go/internal/v2/certverifier"
"github.com/google/s2a-go/internal/v2/remotesigner"
"github.com/google/s2a-go/stream"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
)
const (
// HTTP/2
h2 = "h2"
)
// GetTLSConfigurationForClient returns a tls.Config instance for use by a client application.
func GetTLSConfigurationForClient(serverHostname string, s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentity *commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, serverAuthorizationPolicy []byte) (*tls.Config, error) {
authMechanisms := getAuthMechanisms(tokenManager, []*commonpb.Identity{localIdentity})
if grpclog.V(1) {
grpclog.Infof("Sending request to S2Av2 for client TLS config.")
}
// Send request to S2Av2 for config.
if err := s2AStream.Send(&s2av2pb.SessionReq{
LocalIdentity: localIdentity,
AuthenticationMechanisms: authMechanisms,
ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_CLIENT,
},
},
}); err != nil {
grpclog.Infof("Failed to send request to S2Av2 for client TLS config")
return nil, err
}
// Get the response containing config from S2Av2.
resp, err := s2AStream.Recv()
if err != nil {
grpclog.Infof("Failed to receive client TLS config response from S2Av2.")
return nil, err
}
// TODO(rmehta19): Add unit test for this if statement.
if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
}
// Extract TLS configuration from SessionResp.
tlsConfig := resp.GetGetTlsConfigurationResp().GetClientTlsConfiguration()
var cert tls.Certificate
for i, v := range tlsConfig.CertificateChain {
// Populate Certificates field.
block, _ := pem.Decode([]byte(v))
if block == nil {
return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
}
x509Cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, err
}
cert.Certificate = append(cert.Certificate, x509Cert.Raw)
if i == 0 {
cert.Leaf = x509Cert
}
}
if len(tlsConfig.CertificateChain) > 0 {
cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
if cert.PrivateKey == nil {
return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
}
}
minVersion, maxVersion, err := getTLSMinMaxVersionsClient(tlsConfig)
if err != nil {
return nil, err
}
// Create mTLS credentials for client.
config := &tls.Config{
VerifyPeerCertificate: certverifier.VerifyServerCertificateChain(serverHostname, verificationMode, s2AStream, serverAuthorizationPolicy),
ServerName: serverHostname,
InsecureSkipVerify: true, // NOLINT
ClientSessionCache: nil,
SessionTicketsDisabled: true,
MinVersion: minVersion,
MaxVersion: maxVersion,
NextProtos: []string{h2},
}
if len(tlsConfig.CertificateChain) > 0 {
config.Certificates = []tls.Certificate{cert}
}
return config, nil
}
// GetTLSConfigurationForServer returns a tls.Config instance for use by a server application.
func GetTLSConfigurationForServer(s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode) (*tls.Config, error) {
return &tls.Config{
GetConfigForClient: ClientConfig(tokenManager, localIdentities, verificationMode, s2AStream),
}, nil
}
// ClientConfig builds a TLS config for a server to establish a secure
// connection with a client, based on SNI communicated during ClientHello.
// Ensures that server presents the correct certificate to establish a TLS
// connection.
func ClientConfig(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, s2AStream stream.S2AStream) func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
return func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
tlsConfig, err := getServerConfigFromS2Av2(tokenManager, localIdentities, chi.ServerName, s2AStream)
if err != nil {
return nil, err
}
var cert tls.Certificate
for i, v := range tlsConfig.CertificateChain {
// Populate Certificates field.
block, _ := pem.Decode([]byte(v))
if block == nil {
return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
}
x509Cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, err
}
cert.Certificate = append(cert.Certificate, x509Cert.Raw)
if i == 0 {
cert.Leaf = x509Cert
}
}
cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
if cert.PrivateKey == nil {
return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
}
minVersion, maxVersion, err := getTLSMinMaxVersionsServer(tlsConfig)
if err != nil {
return nil, err
}
clientAuth := getTLSClientAuthType(tlsConfig)
var cipherSuites []uint16
cipherSuites = getCipherSuites(tlsConfig.Ciphersuites)
// Create mTLS credentials for server.
return &tls.Config{
Certificates: []tls.Certificate{cert},
VerifyPeerCertificate: certverifier.VerifyClientCertificateChain(verificationMode, s2AStream),
ClientAuth: clientAuth,
CipherSuites: cipherSuites,
SessionTicketsDisabled: true,
MinVersion: minVersion,
MaxVersion: maxVersion,
NextProtos: []string{h2},
}, nil
}
}
func getCipherSuites(tlsConfigCipherSuites []commonpb.Ciphersuite) []uint16 {
var tlsGoCipherSuites []uint16
for _, v := range tlsConfigCipherSuites {
s := getTLSCipherSuite(v)
if s != 0xffff {
tlsGoCipherSuites = append(tlsGoCipherSuites, s)
}
}
return tlsGoCipherSuites
}
func getTLSCipherSuite(tlsCipherSuite commonpb.Ciphersuite) uint16 {
switch tlsCipherSuite {
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
return tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
return tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384:
return tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256:
return tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
default:
return 0xffff
}
}
func getServerConfigFromS2Av2(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity, sni string, s2AStream stream.S2AStream) (*s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration, error) {
authMechanisms := getAuthMechanisms(tokenManager, localIdentities)
var locID *commonpb.Identity
if localIdentities != nil {
locID = localIdentities[0]
}
if err := s2AStream.Send(&s2av2pb.SessionReq{
LocalIdentity: locID,
AuthenticationMechanisms: authMechanisms,
ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_SERVER,
Sni: sni,
},
},
}); err != nil {
return nil, err
}
resp, err := s2AStream.Recv()
if err != nil {
return nil, err
}
// TODO(rmehta19): Add unit test for this if statement.
if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
}
return resp.GetGetTlsConfigurationResp().GetServerTlsConfiguration(), nil
}
func getTLSClientAuthType(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) tls.ClientAuthType {
var clientAuth tls.ClientAuthType
switch x := tlsConfig.RequestClientCertificate; x {
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_DONT_REQUEST_CLIENT_CERTIFICATE:
clientAuth = tls.NoClientCert
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
clientAuth = tls.RequestClientCert
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
// This case actually maps to tls.VerifyClientCertIfGiven. However this
// mapping triggers normal verification, followed by custom verification,
// specified in VerifyPeerCertificate. To bypass normal verification, and
// only do custom verification we set clientAuth to RequireAnyClientCert or
// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
// discussion.
clientAuth = tls.RequireAnyClientCert
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
clientAuth = tls.RequireAnyClientCert
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
// This case actually maps to tls.RequireAndVerifyClientCert. However this
// mapping triggers normal verification, followed by custom verification,
// specified in VerifyPeerCertificate. To bypass normal verification, and
// only do custom verification we set clientAuth to RequireAnyClientCert or
// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
// discussion.
clientAuth = tls.RequireAnyClientCert
default:
clientAuth = tls.RequireAnyClientCert
}
return clientAuth
}
func getAuthMechanisms(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity) []*s2av2pb.AuthenticationMechanism {
if tokenManager == nil {
return nil
}
if len(localIdentities) == 0 {
token, err := tokenManager.DefaultToken()
if err != nil {
grpclog.Infof("Unable to get token for empty local identity: %v", err)
return nil
}
return []*s2av2pb.AuthenticationMechanism{
{
MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
Token: token,
},
},
}
}
var authMechanisms []*s2av2pb.AuthenticationMechanism
for _, localIdentity := range localIdentities {
if localIdentity == nil {
token, err := tokenManager.DefaultToken()
if err != nil {
grpclog.Infof("Unable to get default token for local identity %v: %v", localIdentity, err)
continue
}
authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
Identity: localIdentity,
MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
Token: token,
},
})
} else {
token, err := tokenManager.Token(localIdentity)
if err != nil {
grpclog.Infof("Unable to get token for local identity %v: %v", localIdentity, err)
continue
}
authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
Identity: localIdentity,
MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
Token: token,
},
})
}
}
return authMechanisms
}
// TODO(rmehta19): refactor switch statements into a helper function.
func getTLSMinMaxVersionsClient(tlsConfig *s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration) (uint16, uint16, error) {
// Map S2Av2 TLSVersion to consts defined in tls package.
var minVersion uint16
var maxVersion uint16
switch x := tlsConfig.MinTlsVersion; x {
case commonpb.TLSVersion_TLS_VERSION_1_0:
minVersion = tls.VersionTLS10
case commonpb.TLSVersion_TLS_VERSION_1_1:
minVersion = tls.VersionTLS11
case commonpb.TLSVersion_TLS_VERSION_1_2:
minVersion = tls.VersionTLS12
case commonpb.TLSVersion_TLS_VERSION_1_3:
minVersion = tls.VersionTLS13
default:
return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
}
switch x := tlsConfig.MaxTlsVersion; x {
case commonpb.TLSVersion_TLS_VERSION_1_0:
maxVersion = tls.VersionTLS10
case commonpb.TLSVersion_TLS_VERSION_1_1:
maxVersion = tls.VersionTLS11
case commonpb.TLSVersion_TLS_VERSION_1_2:
maxVersion = tls.VersionTLS12
case commonpb.TLSVersion_TLS_VERSION_1_3:
maxVersion = tls.VersionTLS13
default:
return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
}
if minVersion > maxVersion {
return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
}
return minVersion, maxVersion, nil
}
func getTLSMinMaxVersionsServer(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) (uint16, uint16, error) {
// Map S2Av2 TLSVersion to consts defined in tls package.
var minVersion uint16
var maxVersion uint16
switch x := tlsConfig.MinTlsVersion; x {
case commonpb.TLSVersion_TLS_VERSION_1_0:
minVersion = tls.VersionTLS10
case commonpb.TLSVersion_TLS_VERSION_1_1:
minVersion = tls.VersionTLS11
case commonpb.TLSVersion_TLS_VERSION_1_2:
minVersion = tls.VersionTLS12
case commonpb.TLSVersion_TLS_VERSION_1_3:
minVersion = tls.VersionTLS13
default:
return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
}
switch x := tlsConfig.MaxTlsVersion; x {
case commonpb.TLSVersion_TLS_VERSION_1_0:
maxVersion = tls.VersionTLS10
case commonpb.TLSVersion_TLS_VERSION_1_1:
maxVersion = tls.VersionTLS11
case commonpb.TLSVersion_TLS_VERSION_1_2:
maxVersion = tls.VersionTLS12
case commonpb.TLSVersion_TLS_VERSION_1_3:
maxVersion = tls.VersionTLS13
default:
return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
}
if minVersion > maxVersion {
return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
}
return minVersion, maxVersion, nil
}