mirror of
https://github.com/Luzifer/culmqtt.git
synced 2024-11-14 00:32:51 +00:00
336 lines
8.9 KiB
Go
336 lines
8.9 KiB
Go
|
/*
|
||
|
* Copyright (c) 2013 IBM Corp.
|
||
|
*
|
||
|
* All rights reserved. This program and the accompanying materials
|
||
|
* are made available under the terms of the Eclipse Public License v1.0
|
||
|
* which accompanies this distribution, and is available at
|
||
|
* http://www.eclipse.org/legal/epl-v10.html
|
||
|
*
|
||
|
* Contributors:
|
||
|
* Seth Hoenig
|
||
|
* Allan Stockdill-Mander
|
||
|
* Mike Robertson
|
||
|
*/
|
||
|
|
||
|
package mqtt
|
||
|
|
||
|
import (
|
||
|
"crypto/tls"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"net/url"
|
||
|
"os"
|
||
|
"reflect"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/eclipse/paho.mqtt.golang/packets"
|
||
|
"golang.org/x/net/proxy"
|
||
|
"golang.org/x/net/websocket"
|
||
|
)
|
||
|
|
||
|
// signalError offers err to c without ever blocking the caller. When the
// send cannot proceed immediately (no ready receiver and no free buffer
// slot) the error is silently discarded.
func signalError(c chan<- error, err error) {
	select {
	default:
		// Nobody can take the error right now; drop it.
	case c <- err:
		// Delivered.
	}
}
|
||
|
|
||
|
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration) (net.Conn, error) {
|
||
|
switch uri.Scheme {
|
||
|
case "ws":
|
||
|
conn, err := websocket.Dial(uri.String(), "mqtt", fmt.Sprintf("http://%s", uri.Host))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
conn.PayloadType = websocket.BinaryFrame
|
||
|
return conn, err
|
||
|
case "wss":
|
||
|
config, _ := websocket.NewConfig(uri.String(), fmt.Sprintf("https://%s", uri.Host))
|
||
|
config.Protocol = []string{"mqtt"}
|
||
|
config.TlsConfig = tlsc
|
||
|
conn, err := websocket.DialConfig(config)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
conn.PayloadType = websocket.BinaryFrame
|
||
|
return conn, err
|
||
|
case "tcp":
|
||
|
allProxy := os.Getenv("all_proxy")
|
||
|
if len(allProxy) == 0 {
|
||
|
conn, err := net.DialTimeout("tcp", uri.Host, timeout)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return conn, nil
|
||
|
}
|
||
|
proxyDialer := proxy.FromEnvironment()
|
||
|
|
||
|
conn, err := proxyDialer.Dial("tcp", uri.Host)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return conn, nil
|
||
|
case "unix":
|
||
|
conn, err := net.DialTimeout("unix", uri.Host, timeout)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return conn, nil
|
||
|
case "ssl":
|
||
|
fallthrough
|
||
|
case "tls":
|
||
|
fallthrough
|
||
|
case "tcps":
|
||
|
allProxy := os.Getenv("all_proxy")
|
||
|
if len(allProxy) == 0 {
|
||
|
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", uri.Host, tlsc)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return conn, nil
|
||
|
}
|
||
|
proxyDialer := proxy.FromEnvironment()
|
||
|
|
||
|
conn, err := proxyDialer.Dial("tcp", uri.Host)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
tlsConn := tls.Client(conn, tlsc)
|
||
|
|
||
|
err = tlsConn.Handshake()
|
||
|
if err != nil {
|
||
|
conn.Close()
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return tlsConn, nil
|
||
|
}
|
||
|
return nil, errors.New("Unknown protocol")
|
||
|
}
|
||
|
|
||
|
// actually read incoming messages off the wire
|
||
|
// send Message object into ibound channel
|
||
|
func incoming(c *client) {
|
||
|
var err error
|
||
|
var cp packets.ControlPacket
|
||
|
|
||
|
defer c.workers.Done()
|
||
|
|
||
|
DEBUG.Println(NET, "incoming started")
|
||
|
|
||
|
for {
|
||
|
if cp, err = packets.ReadPacket(c.conn); err != nil {
|
||
|
break
|
||
|
}
|
||
|
DEBUG.Println(NET, "Received Message")
|
||
|
select {
|
||
|
case c.ibound <- cp:
|
||
|
// Notify keepalive logic that we recently received a packet
|
||
|
if c.options.KeepAlive != 0 {
|
||
|
atomic.StoreInt64(&c.lastReceived, time.Now().Unix())
|
||
|
}
|
||
|
case <-c.stop:
|
||
|
// This avoids a deadlock should a message arrive while shutting down.
|
||
|
// In that case the "reader" of c.ibound might already be gone
|
||
|
WARN.Println(NET, "incoming dropped a received message during shutdown")
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
// We received an error on read.
|
||
|
// If disconnect is in progress, swallow error and return
|
||
|
select {
|
||
|
case <-c.stop:
|
||
|
DEBUG.Println(NET, "incoming stopped")
|
||
|
return
|
||
|
// Not trying to disconnect, send the error to the errors channel
|
||
|
default:
|
||
|
ERROR.Println(NET, "incoming stopped with error", err)
|
||
|
signalError(c.errors, err)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// receive a Message object on obound, and then
|
||
|
// actually send outgoing message to the wire
|
||
|
func outgoing(c *client) {
|
||
|
defer c.workers.Done()
|
||
|
DEBUG.Println(NET, "outgoing started")
|
||
|
|
||
|
for {
|
||
|
DEBUG.Println(NET, "outgoing waiting for an outbound message")
|
||
|
select {
|
||
|
case <-c.stop:
|
||
|
DEBUG.Println(NET, "outgoing stopped")
|
||
|
return
|
||
|
case pub := <-c.obound:
|
||
|
msg := pub.p.(*packets.PublishPacket)
|
||
|
|
||
|
if c.options.WriteTimeout > 0 {
|
||
|
c.conn.SetWriteDeadline(time.Now().Add(c.options.WriteTimeout))
|
||
|
}
|
||
|
|
||
|
if err := msg.Write(c.conn); err != nil {
|
||
|
ERROR.Println(NET, "outgoing stopped with error", err)
|
||
|
pub.t.setError(err)
|
||
|
signalError(c.errors, err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if c.options.WriteTimeout > 0 {
|
||
|
// If we successfully wrote, we don't want the timeout to happen during an idle period
|
||
|
// so we reset it to infinite.
|
||
|
c.conn.SetWriteDeadline(time.Time{})
|
||
|
}
|
||
|
|
||
|
if msg.Qos == 0 {
|
||
|
pub.t.flowComplete()
|
||
|
}
|
||
|
DEBUG.Println(NET, "obound wrote msg, id:", msg.MessageID)
|
||
|
case msg := <-c.oboundP:
|
||
|
switch msg.p.(type) {
|
||
|
case *packets.SubscribePacket:
|
||
|
msg.p.(*packets.SubscribePacket).MessageID = c.getID(msg.t)
|
||
|
case *packets.UnsubscribePacket:
|
||
|
msg.p.(*packets.UnsubscribePacket).MessageID = c.getID(msg.t)
|
||
|
}
|
||
|
DEBUG.Println(NET, "obound priority msg to write, type", reflect.TypeOf(msg.p))
|
||
|
if err := msg.p.Write(c.conn); err != nil {
|
||
|
ERROR.Println(NET, "outgoing stopped with error", err)
|
||
|
msg.t.setError(err)
|
||
|
signalError(c.errors, err)
|
||
|
return
|
||
|
}
|
||
|
switch msg.p.(type) {
|
||
|
case *packets.DisconnectPacket:
|
||
|
msg.t.(*DisconnectToken).flowComplete()
|
||
|
DEBUG.Println(NET, "outbound wrote disconnect, stopping")
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
// Reset ping timer after sending control packet.
|
||
|
if c.options.KeepAlive != 0 {
|
||
|
atomic.StoreInt64(&c.lastSent, time.Now().Unix())
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// receive Message objects on ibound
|
||
|
// store messages if necessary
|
||
|
// send replies on obound
|
||
|
// delete messages from store if necessary
|
||
|
func alllogic(c *client) {
|
||
|
defer c.workers.Done()
|
||
|
DEBUG.Println(NET, "logic started")
|
||
|
|
||
|
for {
|
||
|
DEBUG.Println(NET, "logic waiting for msg on ibound")
|
||
|
|
||
|
select {
|
||
|
case msg := <-c.ibound:
|
||
|
DEBUG.Println(NET, "logic got msg on ibound")
|
||
|
persistInbound(c.persist, msg)
|
||
|
switch m := msg.(type) {
|
||
|
case *packets.PingrespPacket:
|
||
|
DEBUG.Println(NET, "received pingresp")
|
||
|
atomic.StoreInt32(&c.pingOutstanding, 0)
|
||
|
case *packets.SubackPacket:
|
||
|
DEBUG.Println(NET, "received suback, id:", m.MessageID)
|
||
|
token := c.getToken(m.MessageID)
|
||
|
switch t := token.(type) {
|
||
|
case *SubscribeToken:
|
||
|
DEBUG.Println(NET, "granted qoss", m.ReturnCodes)
|
||
|
for i, qos := range m.ReturnCodes {
|
||
|
t.subResult[t.subs[i]] = qos
|
||
|
}
|
||
|
}
|
||
|
token.flowComplete()
|
||
|
c.freeID(m.MessageID)
|
||
|
case *packets.UnsubackPacket:
|
||
|
DEBUG.Println(NET, "received unsuback, id:", m.MessageID)
|
||
|
c.getToken(m.MessageID).flowComplete()
|
||
|
c.freeID(m.MessageID)
|
||
|
case *packets.PublishPacket:
|
||
|
DEBUG.Println(NET, "received publish, msgId:", m.MessageID)
|
||
|
DEBUG.Println(NET, "putting msg on onPubChan")
|
||
|
switch m.Qos {
|
||
|
case 2:
|
||
|
c.incomingPubChan <- m
|
||
|
DEBUG.Println(NET, "done putting msg on incomingPubChan")
|
||
|
pr := packets.NewControlPacket(packets.Pubrec).(*packets.PubrecPacket)
|
||
|
pr.MessageID = m.MessageID
|
||
|
DEBUG.Println(NET, "putting pubrec msg on obound")
|
||
|
select {
|
||
|
case c.oboundP <- &PacketAndToken{p: pr, t: nil}:
|
||
|
case <-c.stop:
|
||
|
}
|
||
|
DEBUG.Println(NET, "done putting pubrec msg on obound")
|
||
|
case 1:
|
||
|
c.incomingPubChan <- m
|
||
|
DEBUG.Println(NET, "done putting msg on incomingPubChan")
|
||
|
pa := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
|
||
|
pa.MessageID = m.MessageID
|
||
|
DEBUG.Println(NET, "putting puback msg on obound")
|
||
|
persistOutbound(c.persist, pa)
|
||
|
select {
|
||
|
case c.oboundP <- &PacketAndToken{p: pa, t: nil}:
|
||
|
case <-c.stop:
|
||
|
}
|
||
|
DEBUG.Println(NET, "done putting puback msg on obound")
|
||
|
case 0:
|
||
|
select {
|
||
|
case c.incomingPubChan <- m:
|
||
|
case <-c.stop:
|
||
|
}
|
||
|
DEBUG.Println(NET, "done putting msg on incomingPubChan")
|
||
|
}
|
||
|
case *packets.PubackPacket:
|
||
|
DEBUG.Println(NET, "received puback, id:", m.MessageID)
|
||
|
// c.receipts.get(msg.MsgId()) <- Receipt{}
|
||
|
// c.receipts.end(msg.MsgId())
|
||
|
c.getToken(m.MessageID).flowComplete()
|
||
|
c.freeID(m.MessageID)
|
||
|
case *packets.PubrecPacket:
|
||
|
DEBUG.Println(NET, "received pubrec, id:", m.MessageID)
|
||
|
prel := packets.NewControlPacket(packets.Pubrel).(*packets.PubrelPacket)
|
||
|
prel.MessageID = m.MessageID
|
||
|
select {
|
||
|
case c.oboundP <- &PacketAndToken{p: prel, t: nil}:
|
||
|
case <-c.stop:
|
||
|
}
|
||
|
case *packets.PubrelPacket:
|
||
|
DEBUG.Println(NET, "received pubrel, id:", m.MessageID)
|
||
|
pc := packets.NewControlPacket(packets.Pubcomp).(*packets.PubcompPacket)
|
||
|
pc.MessageID = m.MessageID
|
||
|
persistOutbound(c.persist, pc)
|
||
|
select {
|
||
|
case c.oboundP <- &PacketAndToken{p: pc, t: nil}:
|
||
|
case <-c.stop:
|
||
|
}
|
||
|
case *packets.PubcompPacket:
|
||
|
DEBUG.Println(NET, "received pubcomp, id:", m.MessageID)
|
||
|
c.getToken(m.MessageID).flowComplete()
|
||
|
c.freeID(m.MessageID)
|
||
|
}
|
||
|
case <-c.stop:
|
||
|
WARN.Println(NET, "logic stopped")
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func errorWatch(c *client) {
|
||
|
defer c.workers.Done()
|
||
|
select {
|
||
|
case <-c.stop:
|
||
|
WARN.Println(NET, "errorWatch stopped")
|
||
|
return
|
||
|
case err := <-c.errors:
|
||
|
ERROR.Println(NET, "error triggered, stopping")
|
||
|
go c.internalConnLost(err)
|
||
|
return
|
||
|
}
|
||
|
}
|