mirror of
https://github.com/Luzifer/ws-relay.git
synced 2024-11-08 14:20:01 +00:00
93 lines
1.9 KiB
Go
93 lines
1.9 KiB
Go
package main
|
|
|
|
import (
|
|
"path"
|
|
"sync"
|
|
|
|
"github.com/gofrs/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// pool is the package-level socket pool, created once at program start.
var pool = newSocketPool()
|
|
|
|
type (
|
|
socketPool struct {
|
|
lock sync.RWMutex
|
|
pool map[string]map[string]*websocket.Conn
|
|
sendQueue *namedLocker
|
|
}
|
|
)
|
|
|
|
func newSocketPool() *socketPool {
|
|
return &socketPool{
|
|
pool: make(map[string]map[string]*websocket.Conn),
|
|
sendQueue: newNamedLocker(),
|
|
}
|
|
}
|
|
|
|
func (s *socketPool) Register(name string, conn *websocket.Conn) (string, func()) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
|
|
connID := uuid.Must(uuid.NewV4()).String()
|
|
|
|
if s.pool[name] == nil {
|
|
s.pool[name] = map[string]*websocket.Conn{}
|
|
}
|
|
|
|
s.pool[name][connID] = conn
|
|
logrus.
|
|
WithFields(logrus.Fields{"id": connID, "socket": name}).
|
|
Info("registered socket")
|
|
|
|
return connID, func() { s.Unregister(name, connID) }
|
|
}
|
|
|
|
func (s *socketPool) Send(name string, msgType int, msg []byte) {
|
|
s.lock.RLock()
|
|
defer s.lock.RUnlock()
|
|
|
|
wg := new(sync.WaitGroup)
|
|
|
|
for id := range s.pool[name] {
|
|
wg.Add(1)
|
|
go s.SendLocked(wg, name, id, msgType, msg)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func (s *socketPool) SendLocked(wg *sync.WaitGroup, name, id string, msgType int, msg []byte) {
|
|
defer wg.Done()
|
|
|
|
s.sendQueue.Lock(path.Join(name, id))
|
|
defer s.sendQueue.Unlock(path.Join(name, id))
|
|
|
|
if err := s.pool[name][id].WriteMessage(msgType, msg); err != nil {
|
|
logrus.
|
|
WithError(err).
|
|
WithFields(logrus.Fields{"id": id, "socket": name}).
|
|
Error("delivering to socket")
|
|
s.Unregister(name, id)
|
|
}
|
|
}
|
|
|
|
func (s *socketPool) Unregister(name, connID string) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
|
|
if s.pool[name] == nil || s.pool[name][connID] == nil {
|
|
return
|
|
}
|
|
|
|
logger := logrus.
|
|
WithFields(logrus.Fields{"id": connID, "socket": name})
|
|
|
|
if err := s.pool[name][connID].Close(); err != nil {
|
|
logger.WithError(err).Error("closing socket connection (leaked fd)")
|
|
}
|
|
delete(s.pool[name], connID)
|
|
|
|
logger.Info("unregistered socket")
|
|
}
|