diff --git a/irc.go b/irc.go index 888befb..f8eed77 100644 --- a/irc.go +++ b/irc.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "fmt" "math" @@ -45,9 +46,11 @@ func registerRawMessageHandler(fn plugins.RawMessageHandlerFunc) error { } type ircHandler struct { - conn *tls.Conn - c *irc.Client - user string + c *irc.Client + conn *tls.Conn + ctx context.Context + ctxCancelFn func() + user string } func newIRCHandler() (*ircHandler, error) { @@ -58,6 +61,8 @@ func newIRCHandler() (*ircHandler, error) { return nil, errors.Wrap(err, "fetching username") } + h.ctx, h.ctxCancelFn = context.WithCancel(context.Background()) + conn, err := tls.Dial("tcp", "irc.chat.twitch.tv:6697", nil) if err != nil { return nil, errors.Wrap(err, "connect to IRC server") @@ -86,7 +91,10 @@ func newIRCHandler() (*ircHandler, error) { func (i ircHandler) Client() *irc.Client { return i.c } -func (i ircHandler) Close() error { return i.conn.Close() } +func (i ircHandler) Close() error { + i.ctxCancelFn() + return nil +} func (i ircHandler) ExecuteJoins(channels []string) { for _, ch := range channels { @@ -197,7 +205,7 @@ func (i ircHandler) Handle(c *irc.Client, m *irc.Message) { } } -func (i ircHandler) Run() error { return errors.Wrap(i.c.Run(), "running IRC client") } +func (i ircHandler) Run() error { return errors.Wrap(i.c.RunContext(i.ctx), "running IRC client") } func (i ircHandler) SendMessage(m *irc.Message) error { return i.c.WriteMessage(m) }