mirror of
https://github.com/Luzifer/nginx-sso.git
synced 2024-12-21 05:11:17 +00:00
290 lines
7.6 KiB
Go
290 lines
7.6 KiB
Go
// Copyright 2012 Google Inc. All rights reserved.
|
|
// Use of this source code is governed by the Apache 2.0
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// +build appengine
|
|
|
|
package socket
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"golang.org/x/net/context"
|
|
"google.golang.org/appengine/internal"
|
|
|
|
pb "google.golang.org/appengine/internal/socket"
|
|
)
|
|
|
|
// Dial connects to the address addr on the network protocol.
|
|
// The address format is host:port, where host may be a hostname or an IP address.
|
|
// Known protocols are "tcp" and "udp".
|
|
// The returned connection satisfies net.Conn, and is valid while ctx is valid;
|
|
// if the connection is to be used after ctx becomes invalid, invoke SetContext
|
|
// with the new context.
|
|
func Dial(ctx context.Context, protocol, addr string) (*Conn, error) {
|
|
return DialTimeout(ctx, protocol, addr, 0)
|
|
}
|
|
|
|
var ipFamilies = []pb.CreateSocketRequest_SocketFamily{
|
|
pb.CreateSocketRequest_IPv4,
|
|
pb.CreateSocketRequest_IPv6,
|
|
}
|
|
|
|
// DialTimeout is like Dial but takes a timeout.
|
|
// The timeout includes name resolution, if required.
|
|
func DialTimeout(ctx context.Context, protocol, addr string, timeout time.Duration) (*Conn, error) {
|
|
dialCtx := ctx // Used for dialing and name resolution, but not stored in the *Conn.
|
|
if timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
dialCtx, cancel = context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
host, portStr, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
port, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("socket: bad port %q: %v", portStr, err)
|
|
}
|
|
|
|
var prot pb.CreateSocketRequest_SocketProtocol
|
|
switch protocol {
|
|
case "tcp":
|
|
prot = pb.CreateSocketRequest_TCP
|
|
case "udp":
|
|
prot = pb.CreateSocketRequest_UDP
|
|
default:
|
|
return nil, fmt.Errorf("socket: unknown protocol %q", protocol)
|
|
}
|
|
|
|
packedAddrs, resolved, err := resolve(dialCtx, ipFamilies, host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err)
|
|
}
|
|
if len(packedAddrs) == 0 {
|
|
return nil, fmt.Errorf("no addresses for %q", host)
|
|
}
|
|
|
|
packedAddr := packedAddrs[0] // use first address
|
|
fam := pb.CreateSocketRequest_IPv4
|
|
if len(packedAddr) == net.IPv6len {
|
|
fam = pb.CreateSocketRequest_IPv6
|
|
}
|
|
|
|
req := &pb.CreateSocketRequest{
|
|
Family: fam.Enum(),
|
|
Protocol: prot.Enum(),
|
|
RemoteIp: &pb.AddressPort{
|
|
Port: proto.Int32(int32(port)),
|
|
PackedAddress: packedAddr,
|
|
},
|
|
}
|
|
if resolved {
|
|
req.RemoteIp.HostnameHint = &host
|
|
}
|
|
res := &pb.CreateSocketReply{}
|
|
if err := internal.Call(dialCtx, "remote_socket", "CreateSocket", req, res); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Conn{
|
|
ctx: ctx,
|
|
desc: res.GetSocketDescriptor(),
|
|
prot: prot,
|
|
local: res.ProxyExternalIp,
|
|
remote: req.RemoteIp,
|
|
}, nil
|
|
}
|
|
|
|
// LookupIP returns the given host's IP addresses.
|
|
func LookupIP(ctx context.Context, host string) (addrs []net.IP, err error) {
|
|
packedAddrs, _, err := resolve(ctx, ipFamilies, host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err)
|
|
}
|
|
addrs = make([]net.IP, len(packedAddrs))
|
|
for i, pa := range packedAddrs {
|
|
addrs[i] = net.IP(pa)
|
|
}
|
|
return addrs, nil
|
|
}
|
|
|
|
func resolve(ctx context.Context, fams []pb.CreateSocketRequest_SocketFamily, host string) ([][]byte, bool, error) {
|
|
// Check if it's an IP address.
|
|
if ip := net.ParseIP(host); ip != nil {
|
|
if ip := ip.To4(); ip != nil {
|
|
return [][]byte{ip}, false, nil
|
|
}
|
|
return [][]byte{ip}, false, nil
|
|
}
|
|
|
|
req := &pb.ResolveRequest{
|
|
Name: &host,
|
|
AddressFamilies: fams,
|
|
}
|
|
res := &pb.ResolveReply{}
|
|
if err := internal.Call(ctx, "remote_socket", "Resolve", req, res); err != nil {
|
|
// XXX: need to map to pb.ResolveReply_ErrorCode?
|
|
return nil, false, err
|
|
}
|
|
return res.PackedAddress, true, nil
|
|
}
|
|
|
|
// withDeadline is like context.WithDeadline, except it ignores the zero deadline.
|
|
func withDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) {
|
|
if deadline.IsZero() {
|
|
return parent, func() {}
|
|
}
|
|
return context.WithDeadline(parent, deadline)
|
|
}
|
|
|
|
// Conn represents a socket connection.
|
|
// It implements net.Conn.
|
|
type Conn struct {
|
|
ctx context.Context
|
|
desc string
|
|
offset int64
|
|
|
|
prot pb.CreateSocketRequest_SocketProtocol
|
|
local, remote *pb.AddressPort
|
|
|
|
readDeadline, writeDeadline time.Time // optional
|
|
}
|
|
|
|
// SetContext sets the context that is used by this Conn.
|
|
// It is usually used only when using a Conn that was created in a different context,
|
|
// such as when a connection is created during a warmup request but used while
|
|
// servicing a user request.
|
|
func (cn *Conn) SetContext(ctx context.Context) {
|
|
cn.ctx = ctx
|
|
}
|
|
|
|
func (cn *Conn) Read(b []byte) (n int, err error) {
|
|
const maxRead = 1 << 20
|
|
if len(b) > maxRead {
|
|
b = b[:maxRead]
|
|
}
|
|
|
|
req := &pb.ReceiveRequest{
|
|
SocketDescriptor: &cn.desc,
|
|
DataSize: proto.Int32(int32(len(b))),
|
|
}
|
|
res := &pb.ReceiveReply{}
|
|
if !cn.readDeadline.IsZero() {
|
|
req.TimeoutSeconds = proto.Float64(cn.readDeadline.Sub(time.Now()).Seconds())
|
|
}
|
|
ctx, cancel := withDeadline(cn.ctx, cn.readDeadline)
|
|
defer cancel()
|
|
if err := internal.Call(ctx, "remote_socket", "Receive", req, res); err != nil {
|
|
return 0, err
|
|
}
|
|
if len(res.Data) == 0 {
|
|
return 0, io.EOF
|
|
}
|
|
if len(res.Data) > len(b) {
|
|
return 0, fmt.Errorf("socket: internal error: read too much data: %d > %d", len(res.Data), len(b))
|
|
}
|
|
return copy(b, res.Data), nil
|
|
}
|
|
|
|
func (cn *Conn) Write(b []byte) (n int, err error) {
|
|
const lim = 1 << 20 // max per chunk
|
|
|
|
for n < len(b) {
|
|
chunk := b[n:]
|
|
if len(chunk) > lim {
|
|
chunk = chunk[:lim]
|
|
}
|
|
|
|
req := &pb.SendRequest{
|
|
SocketDescriptor: &cn.desc,
|
|
Data: chunk,
|
|
StreamOffset: &cn.offset,
|
|
}
|
|
res := &pb.SendReply{}
|
|
if !cn.writeDeadline.IsZero() {
|
|
req.TimeoutSeconds = proto.Float64(cn.writeDeadline.Sub(time.Now()).Seconds())
|
|
}
|
|
ctx, cancel := withDeadline(cn.ctx, cn.writeDeadline)
|
|
defer cancel()
|
|
if err = internal.Call(ctx, "remote_socket", "Send", req, res); err != nil {
|
|
// assume zero bytes were sent in this RPC
|
|
break
|
|
}
|
|
n += int(res.GetDataSent())
|
|
cn.offset += int64(res.GetDataSent())
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (cn *Conn) Close() error {
|
|
req := &pb.CloseRequest{
|
|
SocketDescriptor: &cn.desc,
|
|
}
|
|
res := &pb.CloseReply{}
|
|
if err := internal.Call(cn.ctx, "remote_socket", "Close", req, res); err != nil {
|
|
return err
|
|
}
|
|
cn.desc = "CLOSED"
|
|
return nil
|
|
}
|
|
|
|
func addr(prot pb.CreateSocketRequest_SocketProtocol, ap *pb.AddressPort) net.Addr {
|
|
if ap == nil {
|
|
return nil
|
|
}
|
|
switch prot {
|
|
case pb.CreateSocketRequest_TCP:
|
|
return &net.TCPAddr{
|
|
IP: net.IP(ap.PackedAddress),
|
|
Port: int(*ap.Port),
|
|
}
|
|
case pb.CreateSocketRequest_UDP:
|
|
return &net.UDPAddr{
|
|
IP: net.IP(ap.PackedAddress),
|
|
Port: int(*ap.Port),
|
|
}
|
|
}
|
|
panic("unknown protocol " + prot.String())
|
|
}
|
|
|
|
func (cn *Conn) LocalAddr() net.Addr { return addr(cn.prot, cn.local) }
|
|
func (cn *Conn) RemoteAddr() net.Addr { return addr(cn.prot, cn.remote) }
|
|
|
|
func (cn *Conn) SetDeadline(t time.Time) error {
|
|
cn.readDeadline = t
|
|
cn.writeDeadline = t
|
|
return nil
|
|
}
|
|
|
|
func (cn *Conn) SetReadDeadline(t time.Time) error {
|
|
cn.readDeadline = t
|
|
return nil
|
|
}
|
|
|
|
func (cn *Conn) SetWriteDeadline(t time.Time) error {
|
|
cn.writeDeadline = t
|
|
return nil
|
|
}
|
|
|
|
// KeepAlive signals that the connection is still in use.
|
|
// It may be called to prevent the socket being closed due to inactivity.
|
|
func (cn *Conn) KeepAlive() error {
|
|
req := &pb.GetSocketNameRequest{
|
|
SocketDescriptor: &cn.desc,
|
|
}
|
|
res := &pb.GetSocketNameReply{}
|
|
return internal.Call(cn.ctx, "remote_socket", "GetSocketName", req, res)
|
|
}
|
|
|
|
func init() {
|
|
internal.RegisterErrorCodeMap("remote_socket", pb.RemoteSocketServiceError_ErrorCode_name)
|
|
}
|