mirror of
https://github.com/Luzifer/nginx-sso.git
synced 2025-01-04 12:06:03 +00:00
331 lines
11 KiB
Go
331 lines
11 KiB
Go
|
/*
|
||
|
*
|
||
|
* Copyright 2018 gRPC authors.
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*
|
||
|
*/
|
||
|
|
||
|
// Package alts implements the ALTS credential support by gRPC library, which
|
||
|
// encapsulates all the state needed by a client to authenticate with a server
|
||
|
// using ALTS and make various assertions, e.g., about the client's identity,
|
||
|
// role, or whether it is authorized to make a particular call.
|
||
|
// This package is experimental.
|
||
|
package alts
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"google.golang.org/grpc/credentials"
|
||
|
core "google.golang.org/grpc/credentials/alts/internal"
|
||
|
"google.golang.org/grpc/credentials/alts/internal/handshaker"
|
||
|
"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
|
||
|
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
|
||
|
"google.golang.org/grpc/grpclog"
|
||
|
)
|
||
|
|
||
|
const (
	// hypervisorHandshakerServiceAddress is the address of the ALTS gRPC
	// handshaker service in the hypervisor; it is used whenever no explicit
	// handshaker address is configured.
	hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"

	// defaultTimeout bounds how long a server-side ALTS handshake may run.
	defaultTimeout = 30 * time.Second

	// Lowest RPC protocol version (major.minor) this implementation accepts.
	protocolVersionMinMajor = 2
	protocolVersionMinMinor = 1

	// Highest RPC protocol version (major.minor) this implementation accepts.
	protocolVersionMaxMajor = 2
	protocolVersionMaxMinor = 1
)
|
||
|
|
||
|
var (
|
||
|
once sync.Once
|
||
|
maxRPCVersion = &altspb.RpcProtocolVersions_Version{
|
||
|
Major: protocolVersionMaxMajor,
|
||
|
Minor: protocolVersionMaxMinor,
|
||
|
}
|
||
|
minRPCVersion = &altspb.RpcProtocolVersions_Version{
|
||
|
Major: protocolVersionMinMajor,
|
||
|
Minor: protocolVersionMinMinor,
|
||
|
}
|
||
|
// ErrUntrustedPlatform is returned from ClientHandshake and
|
||
|
// ServerHandshake is running on a platform where the trustworthiness of
|
||
|
// the handshaker service is not guaranteed.
|
||
|
ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
|
||
|
)
|
||
|
|
||
|
// AuthInfo exposes security information from the ALTS handshake to the
|
||
|
// application. This interface is to be implemented by ALTS. Users should not
|
||
|
// need a brand new implementation of this interface. For situations like
|
||
|
// testing, any new implementation should embed this interface. This allows
|
||
|
// ALTS to add new methods to this interface.
|
||
|
type AuthInfo interface {
|
||
|
// ApplicationProtocol returns application protocol negotiated for the
|
||
|
// ALTS connection.
|
||
|
ApplicationProtocol() string
|
||
|
// RecordProtocol returns the record protocol negotiated for the ALTS
|
||
|
// connection.
|
||
|
RecordProtocol() string
|
||
|
// SecurityLevel returns the security level of the created ALTS secure
|
||
|
// channel.
|
||
|
SecurityLevel() altspb.SecurityLevel
|
||
|
// PeerServiceAccount returns the peer service account.
|
||
|
PeerServiceAccount() string
|
||
|
// LocalServiceAccount returns the local service account.
|
||
|
LocalServiceAccount() string
|
||
|
// PeerRPCVersions returns the RPC version supported by the peer.
|
||
|
PeerRPCVersions() *altspb.RpcProtocolVersions
|
||
|
}
|
||
|
|
||
|
// ClientOptions holds the client-side settings of an ALTS channel; they are
// forwarded to the underlying ALTS handshaker.
type ClientOptions struct {
	// TargetServiceAccounts lists the service accounts the client expects
	// the server to run as.
	TargetServiceAccounts []string
	// HandshakerServiceAddress is the address of the ALTS handshaker gRPC
	// service to dial. When empty, the hypervisor default is used.
	HandshakerServiceAddress string
}
|
||
|
|
||
|
// DefaultClientOptions creates a new ClientOptions object with the default
|
||
|
// values.
|
||
|
func DefaultClientOptions() *ClientOptions {
|
||
|
return &ClientOptions{
|
||
|
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ServerOptions holds the server-side settings of an ALTS channel; they are
// forwarded to the underlying ALTS handshaker.
type ServerOptions struct {
	// HandshakerServiceAddress is the address of the ALTS handshaker gRPC
	// service to dial. When empty, the hypervisor default is used.
	HandshakerServiceAddress string
}
|
||
|
|
||
|
// DefaultServerOptions creates a new ServerOptions object with the default
|
||
|
// values.
|
||
|
func DefaultServerOptions() *ServerOptions {
|
||
|
return &ServerOptions{
|
||
|
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// altsTC is the credentials required for authenticating a connection using ALTS.
|
||
|
// It implements credentials.TransportCredentials interface.
|
||
|
type altsTC struct {
|
||
|
info *credentials.ProtocolInfo
|
||
|
side core.Side
|
||
|
accounts []string
|
||
|
hsAddress string
|
||
|
}
|
||
|
|
||
|
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
|
||
|
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
|
||
|
return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
|
||
|
}
|
||
|
|
||
|
// NewServerCreds constructs a server-side ALTS TransportCredentials object.
|
||
|
func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
|
||
|
return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
|
||
|
}
|
||
|
|
||
|
func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
|
||
|
once.Do(func() {
|
||
|
vmOnGCP = isRunningOnGCP()
|
||
|
})
|
||
|
|
||
|
if hsAddress == "" {
|
||
|
hsAddress = hypervisorHandshakerServiceAddress
|
||
|
}
|
||
|
return &altsTC{
|
||
|
info: &credentials.ProtocolInfo{
|
||
|
SecurityProtocol: "alts",
|
||
|
SecurityVersion: "1.0",
|
||
|
},
|
||
|
side: side,
|
||
|
accounts: accounts,
|
||
|
hsAddress: hsAddress,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ClientHandshake implements the client side handshake protocol.
|
||
|
func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
|
||
|
if !vmOnGCP {
|
||
|
return nil, nil, ErrUntrustedPlatform
|
||
|
}
|
||
|
|
||
|
// Connecting to ALTS handshaker service.
|
||
|
hsConn, err := service.Dial(g.hsAddress)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
// Do not close hsConn since it is shared with other handshakes.
|
||
|
|
||
|
// Possible context leak:
|
||
|
// The cancel function for the child context we create will only be
|
||
|
// called a non-nil error is returned.
|
||
|
var cancel context.CancelFunc
|
||
|
ctx, cancel = context.WithCancel(ctx)
|
||
|
defer func() {
|
||
|
if err != nil {
|
||
|
cancel()
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
opts := handshaker.DefaultClientHandshakerOptions()
|
||
|
opts.TargetName = addr
|
||
|
opts.TargetServiceAccounts = g.accounts
|
||
|
opts.RPCVersions = &altspb.RpcProtocolVersions{
|
||
|
MaxRpcVersion: maxRPCVersion,
|
||
|
MinRpcVersion: minRPCVersion,
|
||
|
}
|
||
|
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
|
||
|
defer func() {
|
||
|
if err != nil {
|
||
|
chs.Close()
|
||
|
}
|
||
|
}()
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
secConn, authInfo, err := chs.ClientHandshake(ctx)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
altsAuthInfo, ok := authInfo.(AuthInfo)
|
||
|
if !ok {
|
||
|
return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
|
||
|
}
|
||
|
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
|
||
|
if !match {
|
||
|
return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
|
||
|
}
|
||
|
return secConn, authInfo, nil
|
||
|
}
|
||
|
|
||
|
// ServerHandshake implements the server side ALTS handshaker.
|
||
|
func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
|
||
|
if !vmOnGCP {
|
||
|
return nil, nil, ErrUntrustedPlatform
|
||
|
}
|
||
|
// Connecting to ALTS handshaker service.
|
||
|
hsConn, err := service.Dial(g.hsAddress)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
// Do not close hsConn since it's shared with other handshakes.
|
||
|
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
||
|
defer cancel()
|
||
|
opts := handshaker.DefaultServerHandshakerOptions()
|
||
|
opts.RPCVersions = &altspb.RpcProtocolVersions{
|
||
|
MaxRpcVersion: maxRPCVersion,
|
||
|
MinRpcVersion: minRPCVersion,
|
||
|
}
|
||
|
shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
|
||
|
defer func() {
|
||
|
if err != nil {
|
||
|
shs.Close()
|
||
|
}
|
||
|
}()
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
secConn, authInfo, err := shs.ServerHandshake(ctx)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
altsAuthInfo, ok := authInfo.(AuthInfo)
|
||
|
if !ok {
|
||
|
return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
|
||
|
}
|
||
|
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
|
||
|
if !match {
|
||
|
return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
|
||
|
}
|
||
|
return secConn, authInfo, nil
|
||
|
}
|
||
|
|
||
|
func (g *altsTC) Info() credentials.ProtocolInfo {
|
||
|
return *g.info
|
||
|
}
|
||
|
|
||
|
func (g *altsTC) Clone() credentials.TransportCredentials {
|
||
|
info := *g.info
|
||
|
var accounts []string
|
||
|
if g.accounts != nil {
|
||
|
accounts = make([]string, len(g.accounts))
|
||
|
copy(accounts, g.accounts)
|
||
|
}
|
||
|
return &altsTC{
|
||
|
info: &info,
|
||
|
side: g.side,
|
||
|
hsAddress: g.hsAddress,
|
||
|
accounts: accounts,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (g *altsTC) OverrideServerName(serverNameOverride string) error {
|
||
|
g.info.ServerName = serverNameOverride
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
|
||
|
func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
|
||
|
switch {
|
||
|
case v1.GetMajor() > v2.GetMajor(),
|
||
|
v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
|
||
|
return 1
|
||
|
case v1.GetMajor() < v2.GetMajor(),
|
||
|
v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
|
||
|
return -1
|
||
|
}
|
||
|
return 0
|
||
|
}
|
||
|
|
||
|
// checkRPCVersions performs a version check between local and peer rpc protocol
|
||
|
// versions. This function returns true if the check passes which means both
|
||
|
// parties agreed on a common rpc protocol to use, and false otherwise. The
|
||
|
// function also returns the highest common RPC protocol version both parties
|
||
|
// agreed on.
|
||
|
func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
|
||
|
if local == nil || peer == nil {
|
||
|
grpclog.Error("invalid checkRPCVersions argument, either local or peer is nil.")
|
||
|
return false, nil
|
||
|
}
|
||
|
|
||
|
// maxCommonVersion is MIN(local.max, peer.max).
|
||
|
maxCommonVersion := local.GetMaxRpcVersion()
|
||
|
if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
|
||
|
maxCommonVersion = peer.GetMaxRpcVersion()
|
||
|
}
|
||
|
|
||
|
// minCommonVersion is MAX(local.min, peer.min).
|
||
|
minCommonVersion := peer.GetMinRpcVersion()
|
||
|
if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
|
||
|
minCommonVersion = local.GetMinRpcVersion()
|
||
|
}
|
||
|
|
||
|
if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
|
||
|
return false, nil
|
||
|
}
|
||
|
return true, maxCommonVersion
|
||
|
}
|