mirror of
https://github.com/Luzifer/nginx-sso.git
synced 2024-12-21 05:11:17 +00:00
242 lines
5.6 KiB
Go
242 lines
5.6 KiB
Go
|
// Copyright 2018 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
// Package sockstest provides utilities for SOCKS testing.
|
||
|
package sockstest
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"io"
|
||
|
"net"
|
||
|
|
||
|
"golang.org/x/net/internal/socks"
|
||
|
"golang.org/x/net/nettest"
|
||
|
)
|
||
|
|
||
|
// An AuthRequest represents an authentication request.
|
||
|
type AuthRequest struct {
|
||
|
Version int
|
||
|
Methods []socks.AuthMethod
|
||
|
}
|
||
|
|
||
|
// ParseAuthRequest parses an authentication request.
|
||
|
func ParseAuthRequest(b []byte) (*AuthRequest, error) {
|
||
|
if len(b) < 2 {
|
||
|
return nil, errors.New("short auth request")
|
||
|
}
|
||
|
if b[0] != socks.Version5 {
|
||
|
return nil, errors.New("unexpected protocol version")
|
||
|
}
|
||
|
if len(b)-2 < int(b[1]) {
|
||
|
return nil, errors.New("short auth request")
|
||
|
}
|
||
|
req := &AuthRequest{Version: int(b[0])}
|
||
|
if b[1] > 0 {
|
||
|
req.Methods = make([]socks.AuthMethod, b[1])
|
||
|
for i, m := range b[2 : 2+b[1]] {
|
||
|
req.Methods[i] = socks.AuthMethod(m)
|
||
|
}
|
||
|
}
|
||
|
return req, nil
|
||
|
}
|
||
|
|
||
|
// MarshalAuthReply returns an authentication reply in wire format.
|
||
|
func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) {
|
||
|
return []byte{byte(ver), byte(m)}, nil
|
||
|
}
|
||
|
|
||
|
// A CmdRequest repesents a command request.
|
||
|
type CmdRequest struct {
|
||
|
Version int
|
||
|
Cmd socks.Command
|
||
|
Addr socks.Addr
|
||
|
}
|
||
|
|
||
|
// ParseCmdRequest parses a command request.
|
||
|
func ParseCmdRequest(b []byte) (*CmdRequest, error) {
|
||
|
if len(b) < 7 {
|
||
|
return nil, errors.New("short cmd request")
|
||
|
}
|
||
|
if b[0] != socks.Version5 {
|
||
|
return nil, errors.New("unexpected protocol version")
|
||
|
}
|
||
|
if socks.Command(b[1]) != socks.CmdConnect {
|
||
|
return nil, errors.New("unexpected command")
|
||
|
}
|
||
|
if b[2] != 0 {
|
||
|
return nil, errors.New("non-zero reserved field")
|
||
|
}
|
||
|
req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])}
|
||
|
l := 2
|
||
|
off := 4
|
||
|
switch b[3] {
|
||
|
case socks.AddrTypeIPv4:
|
||
|
l += net.IPv4len
|
||
|
req.Addr.IP = make(net.IP, net.IPv4len)
|
||
|
case socks.AddrTypeIPv6:
|
||
|
l += net.IPv6len
|
||
|
req.Addr.IP = make(net.IP, net.IPv6len)
|
||
|
case socks.AddrTypeFQDN:
|
||
|
l += int(b[4])
|
||
|
off = 5
|
||
|
default:
|
||
|
return nil, errors.New("unknown address type")
|
||
|
}
|
||
|
if len(b[off:]) < l {
|
||
|
return nil, errors.New("short cmd request")
|
||
|
}
|
||
|
if req.Addr.IP != nil {
|
||
|
copy(req.Addr.IP, b[off:])
|
||
|
} else {
|
||
|
req.Addr.Name = string(b[off : off+l-2])
|
||
|
}
|
||
|
req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1])
|
||
|
return req, nil
|
||
|
}
|
||
|
|
||
|
// MarshalCmdReply returns a command reply in wire format.
|
||
|
func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) {
|
||
|
b := make([]byte, 4)
|
||
|
b[0] = byte(ver)
|
||
|
b[1] = byte(reply)
|
||
|
if a.Name != "" {
|
||
|
if len(a.Name) > 255 {
|
||
|
return nil, errors.New("fqdn too long")
|
||
|
}
|
||
|
b[3] = socks.AddrTypeFQDN
|
||
|
b = append(b, byte(len(a.Name)))
|
||
|
b = append(b, a.Name...)
|
||
|
} else if ip4 := a.IP.To4(); ip4 != nil {
|
||
|
b[3] = socks.AddrTypeIPv4
|
||
|
b = append(b, ip4...)
|
||
|
} else if ip6 := a.IP.To16(); ip6 != nil {
|
||
|
b[3] = socks.AddrTypeIPv6
|
||
|
b = append(b, ip6...)
|
||
|
} else {
|
||
|
return nil, errors.New("unknown address type")
|
||
|
}
|
||
|
b = append(b, byte(a.Port>>8), byte(a.Port))
|
||
|
return b, nil
|
||
|
}
|
||
|
|
||
|
// Server is a server used for handshake testing.
type Server struct {
	// ln is the listener on which client connections are accepted.
	ln net.Listener
}
|
||
|
|
||
|
// Addr rerurns a server address.
|
||
|
func (s *Server) Addr() net.Addr {
|
||
|
return s.ln.Addr()
|
||
|
}
|
||
|
|
||
|
// TargetAddr returns a fake final destination address.
|
||
|
//
|
||
|
// The returned address is only valid for testing with Server.
|
||
|
func (s *Server) TargetAddr() net.Addr {
|
||
|
a := s.ln.Addr()
|
||
|
switch a := a.(type) {
|
||
|
case *net.TCPAddr:
|
||
|
if a.IP.To4() != nil {
|
||
|
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963}
|
||
|
}
|
||
|
if a.IP.To16() != nil && a.IP.To4() == nil {
|
||
|
return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Close closes the server.
|
||
|
func (s *Server) Close() error {
|
||
|
return s.ln.Close()
|
||
|
}
|
||
|
|
||
|
func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) {
|
||
|
c, err := s.ln.Accept()
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
defer c.Close()
|
||
|
go s.serve(authFunc, cmdFunc)
|
||
|
b := make([]byte, 512)
|
||
|
n, err := c.Read(b)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
if err := authFunc(c, b[:n]); err != nil {
|
||
|
return
|
||
|
}
|
||
|
n, err = c.Read(b)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
if err := cmdFunc(c, b[:n]); err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NewServer returns a new server.
|
||
|
//
|
||
|
// The provided authFunc and cmdFunc must parse requests and return
|
||
|
// appropriate replies to clients.
|
||
|
func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) {
|
||
|
var err error
|
||
|
s := new(Server)
|
||
|
s.ln, err = nettest.NewLocalListener("tcp")
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
go s.serve(authFunc, cmdFunc)
|
||
|
return s, nil
|
||
|
}
|
||
|
|
||
|
// NoAuthRequired handles a no-authentication-required signaling.
|
||
|
func NoAuthRequired(rw io.ReadWriter, b []byte) error {
|
||
|
req, err := ParseAuthRequest(b)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
b, err = MarshalAuthReply(req.Version, socks.AuthMethodNotRequired)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
n, err := rw.Write(b)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if n != len(b) {
|
||
|
return errors.New("short write")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// NoProxyRequired handles a command signaling without constructing a
|
||
|
// proxy connection to the final destination.
|
||
|
func NoProxyRequired(rw io.ReadWriter, b []byte) error {
|
||
|
req, err := ParseCmdRequest(b)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
req.Addr.Port += 1
|
||
|
if req.Addr.Name != "" {
|
||
|
req.Addr.Name = "boundaddr.doesnotexist"
|
||
|
} else if req.Addr.IP.To4() != nil {
|
||
|
req.Addr.IP = net.IPv4(127, 0, 0, 1)
|
||
|
} else {
|
||
|
req.Addr.IP = net.IPv6loopback
|
||
|
}
|
||
|
b, err = MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &req.Addr)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
n, err := rw.Write(b)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if n != len(b) {
|
||
|
return errors.New("short write")
|
||
|
}
|
||
|
return nil
|
||
|
}
|