ws-relay/main.go

113 lines
2.6 KiB
Go
Raw Normal View History

2022-05-21 13:21:51 +00:00
package main
import (
"fmt"
"net/http"
"os"
"strings"
"time"
2022-05-21 13:21:51 +00:00
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/Luzifer/rconfig/v2"
)
const websocketBufferSize = 1024
2022-05-21 13:21:51 +00:00
var (
cfg = struct {
Listen string `flag:"listen" default:":3000" description:"Port/IP to listen on"`
LogLevel string `flag:"log-level" default:"info" description:"Log level (debug, info, warn, error, fatal)"`
VersionAndExit bool `flag:"version" default:"false" description:"Prints current version and exits"`
}{}
upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
ReadBufferSize: websocketBufferSize,
WriteBufferSize: websocketBufferSize,
2022-05-21 13:21:51 +00:00
}
version = "dev"
)
// initApp fills cfg from the command line (and environment, via rconfig's
// AutoEnv) and applies cfg.LogLevel to the global logrus logger.
//
// It returns a wrapped error when the options fail to parse or validate, or
// when the configured log level is not a level logrus recognizes.
func initApp() error {
	rconfig.AutoEnv(true)

	parseErr := rconfig.ParseAndValidate(&cfg)
	if parseErr != nil {
		return errors.Wrap(parseErr, "parsing cli options")
	}

	var level logrus.Level
	if level, parseErr = logrus.ParseLevel(cfg.LogLevel); parseErr != nil {
		return errors.Wrap(parseErr, "parsing log-level")
	}
	logrus.SetLevel(level)

	return nil
}
func main() {
var err error
if err = initApp(); err != nil {
logrus.WithError(err).Fatal("initializing app")
}
if cfg.VersionAndExit {
fmt.Printf("ws-relay %s\n", version) //nolint:forbidigo
os.Exit(0)
}
2022-05-21 13:21:51 +00:00
logrus.WithField("version", version).Info("ws-relay started")
router := mux.NewRouter()
router.HandleFunc("/{socket}", handleSocketRelay)
server := &http.Server{
Addr: cfg.Listen,
Handler: router,
ReadHeaderTimeout: time.Second,
}
if err = server.ListenAndServe(); err != nil {
logrus.WithError(err).Fatal("running HTTP server")
2022-05-21 13:21:51 +00:00
}
}
func handleSocketRelay(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") {
// That's no socket request, don't spam the logs
http.Error(w, "this is a socket", http.StatusBadRequest)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logrus.WithError(err).Error("upgrading socket")
return
}
defer func() {
if err := conn.Close(); err != nil {
logrus.WithError(err).Error("closing socket connection (leaked fd)")
}
}()
2022-05-21 13:21:51 +00:00
var (
socketName = mux.Vars(r)["socket"]
connID, unregister = pool.Register(socketName, conn)
logger = logrus.WithFields(logrus.Fields{"id": connID, "socket": socketName})
)
defer unregister()
for {
msgType, msg, err := conn.ReadMessage()
if err != nil {
logger.WithError(err).Error("reading from connection")
return
}
pool.Send(socketName, msgType, msg)
}
}