From 8590ce7d780785504aac857d08781b8eb44b4929 Mon Sep 17 00:00:00 2001 From: a Date: Wed, 15 May 2024 04:28:02 -0500 Subject: [PATCH] noot --- cmd/lain/main.go | 61 ++++++----- pkg/capability/cap_servertime/capability.go | 46 +++++++++ pkg/capability/capabilities.go | 109 ++++++++++++++++++++ pkg/ircclient/client.go | 30 ++++++ pkg/ircconn/conn.go | 36 ++++--- pkg/ircv3/event.go | 36 +++++++ pkg/ircv3/handler.go | 16 ++- pkg/ircv3/message.go | 10 +- plugins/auth/auth.go | 80 +++++++------- plugins/useful/channels.go | 7 +- plugins/useful/pong.go | 10 +- src/lainbot/lainbot.go | 51 +++++++++ 12 files changed, 394 insertions(+), 98 deletions(-) create mode 100644 pkg/capability/cap_servertime/capability.go create mode 100644 pkg/capability/capabilities.go create mode 100644 pkg/ircclient/client.go create mode 100644 pkg/ircv3/event.go create mode 100644 src/lainbot/lainbot.go diff --git a/cmd/lain/main.go b/cmd/lain/main.go index eb11aba..ab14862 100644 --- a/cmd/lain/main.go +++ b/cmd/lain/main.go @@ -7,59 +7,72 @@ import ( "os" "time" - "tuxpa.in/a/irc/plugins/caps/ircmw" - _ "github.com/joho/godotenv/autoload" "github.com/lmittmann/tint" "go.uber.org/fx" "go.uber.org/fx/fxevent" + "tuxpa.in/a/irc/pkg/capability" + "tuxpa.in/a/irc/pkg/capability/cap_servertime" "tuxpa.in/a/irc/pkg/ircconn" "tuxpa.in/a/irc/pkg/ircv3" "tuxpa.in/a/irc/plugins/auth" "tuxpa.in/a/irc/plugins/useful" + "tuxpa.in/a/irc/src/lainbot" ) func exec(log *slog.Logger) error { + tlsConfig := &tls.Config{} - conn, err := tls.Dial("tcp", "irc.libera.chat:6697", tlsConfig) + conn, err := tls.Dial("tcp", "put.gay:6697", tlsConfig) //conn, err := net.Dial("tcp", "irc.libera.chat:6667") if err != nil { return err } - irc := ircconn.New(log, conn, conn) - if err != nil { - return err - } ctx := context.Background() name := os.Getenv("LAIN_NICKNAME") saslPassword := os.Getenv("LAIN_PASSWORD") - c := &ircmw.Capabilities{} + reg := &capability.Registrar{} + reg.Enable(&cap_servertime.Capability{}) + instance := &lainbot.Instance{ + Nuh: &ircv3.NUH{ + Name: name, + User: "lain", + Host: "wired", + }, + } + + reg.Enable(&auth.SaslPlain{ + Username: name, + Password: saslPassword, + }) + reg.Enable(&auth.Nick{Nick: name}) + reg.Enable(&auth.User{Username: "lain", Realname: "lain a", Hostname: "wired", Server: "wired"}) + handler := ircv3.Chain( func(next ircv3.Handler) ircv3.Handler { - return ircv3.HandlerFunc(func(ctx context.Context, w ircv3.MessageWriter, m *ircv3.Message) { - log.Info("in <<", "msg", m.String()) - next.Handle(ctx, w, m) + return ircv3.HandlerFunc(func(w ircv3.MessageWriter, e *ircv3.Event) { + log.Info("in <<", "msg", e.Msg.String()) + next.Handle(w, e) }) }, - c.Middleware, - (&auth.SaslPlain{ - Username: name, - Password: saslPassword, - }).Middleware, - (&auth.Nick{Nick: name}).Middleware, - (&auth.User{Username: "lain", Realname: "lain a", Hostname: "wired", Server: "wired"}).Middleware, - ircmw.CapabilityServerTime, + reg.NewMiddleware(), (&useful.Autojoin{Channels: []string{"#lainmaxxing"}}).Middleware, (&useful.Pong{}).Middleware, - ).Handler(ircv3.HandlerFunc(func(ctx context.Context, w ircv3.MessageWriter, m *ircv3.Message) { - })) - - err = irc.Serve(ctx, handler) + func(h ircv3.Handler) ircv3.Handler { + return ircv3.HandlerFunc(func(w ircv3.MessageWriter, e *ircv3.Event) { + if e.Type != ircv3.EventTypeIRC { + return + } + h.Handle(w, e) + }) + }, + ).Handler(instance) + mw := &ircconn.MessageWriter{R: conn, Log: log} + err = ircconn.Serve(ctx, conn, mw, handler) if err != nil { return err } return nil - } func main() { diff --git a/pkg/capability/cap_servertime/capability.go b/pkg/capability/cap_servertime/capability.go new file mode 100644 index 0000000..2bce24e --- /dev/null +++ b/pkg/capability/cap_servertime/capability.go @@ -0,0 +1,46 @@ +package cap_servertime + +import ( + "strings" + + "tuxpa.in/a/irc/pkg/ircv3" +) + +type Capability struct { + supported bool +} + +func (c *Capability) Middleware(next ircv3.Handler) ircv3.Handler { + return ircv3.HandlerFunc(func(w ircv3.MessageWriter, e *ircv3.Event) { + if c.supported { + // tString := e.Msg.Tags.Get("time") + // if tString != "" { + // parsedTime, err := time.Parse("2006-01-02T15:04:05.000Z", tString) + // if err == nil { + // ctx = context.WithValue(ctx, keyServerTime, parsedTime) + // } + // } + } + next.Handle(w, e) + }) +} + +func (c *Capability) Handle(w ircv3.MessageWriter, e *ircv3.Event) bool { + // dont have the capability? nothing to do. registered. + if e.Msg.Command == "CAP" && e.Msg.Param(0) == "*" && e.Msg.Param(1) == "LS" { + for _, v := range strings.Fields(e.Msg.Param(2)) { + if v == "server-time" { + w.WriteMessage(ircv3.NewMessage("CAP", "REQ", "server-time")) + return true + } + } + } + if e.Msg.Command == "CAP" && e.Msg.Param(2) == "server-time" { + if e.Msg.Param(1) == "ACK" { + c.supported = true + } else { + c.supported = false + } + } + return false +} diff --git a/pkg/capability/capabilities.go b/pkg/capability/capabilities.go new file mode 100644 index 0000000..7e5a50f --- /dev/null +++ b/pkg/capability/capabilities.go @@ -0,0 +1,109 @@ +package capability + +import ( + "context" + "strings" + + "tuxpa.in/a/irc/pkg/ircv3" +) + +var registrarKey struct{} + +type Capability interface { + Handle(w ircv3.MessageWriter, e *ircv3.Event) bool + Middleware(next ircv3.Handler) ircv3.Handler +} + +type Registrar struct { + enabled []Capability + + discovered []string + negotiating int +} + +func (cs *Registrar) Enable(c Capability) error { + cs.enabled = append(cs.enabled, c) + return nil +} + +func (cs *Registrar) NewMiddleware() func(next ircv3.Handler) ircv3.Handler { + waiting := map[Capability]struct{}{} + for _, v := range cs.enabled { + waiting[v] = struct{}{} + } + return func(next ircv3.Handler) ircv3.Handler { + cur := next + for _, v := range cs.enabled { + cur = v.Middleware(next) + } + return ircv3.HandlerFunc(func(w ircv3.MessageWriter, e *ircv3.Event) { + e = e.WithContext(context.WithValue(e.Context(), registrarKey, *cs)) + // reset signal, set negotiating to 1 + if e.Type == ircv3.EventTypeCONTROL && e.Msg.Command == "/EVENT_ON_SERVE" { + for k := range waiting { + delete(waiting, k) + } + for _, v := range cs.enabled { + waiting[v] = struct{}{} + } + cs.negotiating = 1 + cs.discovered = nil + } + // done negotiating, so run the handle middleware + if cs.negotiating == 4 { + cur.Handle(w, e) + return + } + // increase negotiating stage when receive the CAP * LS response + if cs.negotiating == 2 && e.Msg.Command == "CAP" && e.Msg.Param(0) == "*" && e.Msg.Param(1) == "LS" { + for _, v := range strings.Fields(e.Msg.Param(2)) { + cs.discovered = append(cs.discovered, v) + } + cs.negotiating = 3 + } + if e.Type == ircv3.EventTypeCONTROL && e.Msg.Command == "/EVENT_ON_SERVE" { + w.WriteMessage(ircv3.NewMessage("CAP", "LS", "302")) + cs.negotiating = 2 + } + // run all negotiation handlers + // this allows sasl auth to happen before CAP LS happens. + for v := range waiting { + ready := v.Handle(w, e) + if ready { + delete(waiting, v) + } + } + // not done negotiating yet, so dont run handler middleware + next.Handle(w, e) + + if cs.negotiating == 3 && len(waiting) == 0 { + cs.negotiating = 4 + w.WriteMessage(ircv3.NewMessage("CAP", "END")) + } + }) + } +} + +func NegotiatingState(ctx context.Context) int { + val, ok := ctx.Value(registrarKey).(*Registrar) + if !ok { + return 0 + } + return val.negotiating +} +func IsDoneNegotiating(ctx context.Context) bool { + return NegotiatingState(ctx) >= 3 +} + +func HasCapability(ctx context.Context, c string) bool { + val, ok := ctx.Value(registrarKey).(*Registrar) + if !ok { + return false + } + for _, v := range val.discovered { + if v == c { + return true + } + } + return false +} diff --git a/pkg/ircclient/client.go b/pkg/ircclient/client.go new file mode 100644 index 0000000..65a0dc5 --- /dev/null +++ b/pkg/ircclient/client.go @@ -0,0 +1,30 @@ +package ircclient + +import ( + "context" + "io" + + "tuxpa.in/a/irc/pkg/ircconn" + "tuxpa.in/a/irc/pkg/ircv3" +) + +type Client struct { +} + +func (r *Client) Connect(ctx context.Context, + Connector func() (io.Writer, io.Reader, error), + h ircv3.Handler) error { + wr, rd, err := Connector() + if err != nil { + return err + } + irc := ircconn.New(wr, rd) + if err != nil { + return err + } + err = irc.Serve(ctx, h) + if err != nil { + return err + } + return nil +} diff --git a/pkg/ircconn/conn.go b/pkg/ircconn/conn.go index 422af46..aa885ac 100644 --- a/pkg/ircconn/conn.go +++ b/pkg/ircconn/conn.go @@ -15,25 +15,21 @@ import ( type Conn struct { w io.Writer r io.Reader - log *slog.Logger muWrite sync.Mutex } -func New(log *slog.Logger, w io.Writer, r io.Reader) *Conn { +func New(w io.Writer, r io.Reader) *Conn { return &Conn{ - log: log, - w: w, - r: r, + w: w, + r: r, } } // while serve is running, the conn owns the reader. -func (c *Conn) Serve(ctx context.Context, h ircv3.Handler) error { - // once serve is called, we call with an empty message. - h.Handle(ctx, c, &ircv3.Message{}) +func Serve(ctx context.Context, r io.Reader, wr ircv3.MessageWriter, h ircv3.Handler) error { + h.Handle(wr, ircv3.NewEvent(ctx, ircv3.EventTypeCONTROL, ircv3.NewMessage("/EVENT_ON_SERVE"))) dec := &ircdecoder.Decoder{} - r := c.r - r = bufio.NewReaderSize(c.r, 10240) + r = bufio.NewReaderSize(r, 10240) for { select { case <-ctx.Done(): @@ -46,22 +42,30 @@ func (c *Conn) Serve(ctx context.Context, h ircv3.Handler) error { if err != nil { return err } - h.Handle(ctx, c, msg) + h.Handle(wr, ircv3.NewEvent(ctx, ircv3.EventTypeIRC, msg)) } } -func (c *Conn) WriteMessage(msg *ircv3.Message) error { +type MessageWriter struct { + R io.Writer + Log *slog.Logger + mu sync.Mutex +} + +func (r *MessageWriter) WriteMessage(msg *ircv3.Message) error { b := bytebufferpool.Get() defer bytebufferpool.Put(b) err := msg.Encode(b) if err != nil { return err } + if r.Log != nil { + r.Log.Info("out >", "msg", msg.String()) + } b.WriteString("\r\n") - c.muWrite.Lock() - defer c.muWrite.Unlock() - c.log.Info("out >", "msg", msg.String()) - _, err = b.WriteTo(c.w) + r.mu.Lock() + defer r.mu.Unlock() + _, err = r.R.Write(b.B) if err != nil { return err } diff --git a/pkg/ircv3/event.go b/pkg/ircv3/event.go new file mode 100644 index 0000000..5d6aaf1 --- /dev/null +++ b/pkg/ircv3/event.go @@ -0,0 +1,36 @@ +package ircv3 + +import "context" + +type EventType = string + +const ( + EventTypeIRC EventType = "irc" + EventTypeCONTROL EventType = "control" +) + +type Event struct { + Type EventType + Msg *Message + ctx context.Context +} + +func NewEvent(ctx context.Context, t EventType, msg *Message) *Event { + return &Event{ + ctx: ctx, + Msg: msg, + Type: t, + } +} + +func (e *Event) Context() context.Context { + return e.ctx +} + +func (e *Event) WithContext(ctx context.Context) *Event { + return &Event{ + Msg: e.Msg, + Type: e.Type, + ctx: ctx, + } +} diff --git a/pkg/ircv3/handler.go b/pkg/ircv3/handler.go index e4455f7..62e5b71 100644 --- a/pkg/ircv3/handler.go +++ b/pkg/ircv3/handler.go @@ -1,16 +1,12 @@ package ircv3 -import ( - "context" -) - type Handler interface { - Handle(ctx context.Context, w MessageWriter, m *Message) + Handle(w MessageWriter, m *Event) } -type HandlerFunc func(ctx context.Context, w MessageWriter, m *Message) +type HandlerFunc func(w MessageWriter, e *Event) -func (h HandlerFunc) Handle(ctx context.Context, w MessageWriter, m *Message) { - h(ctx, w, m) +func (h HandlerFunc) Handle(w MessageWriter, e *Event) { + h(w, e) } type MessageWriter interface { @@ -46,8 +42,8 @@ type ChainHandler struct { Middlewares Middlewares } -func (c *ChainHandler) Handle(ctx context.Context, w MessageWriter, m *Message) { - c.chain.Handle(ctx, w, m) +func (c *ChainHandler) Handle(w MessageWriter, e *Event) { + c.chain.Handle(w, e) } // chain builds a http.Handler composed of an inline middleware stack and endpoint diff --git a/pkg/ircv3/message.go b/pkg/ircv3/message.go index fc600ec..f6f89d4 100644 --- a/pkg/ircv3/message.go +++ b/pkg/ircv3/message.go @@ -25,6 +25,7 @@ func (m *Message) SetSource(nuh *NUH) *Message { m.Source = nuh return m } + func (m *Message) Param(i int) string { if len(m.Params) > i { return m.Params[i] @@ -52,6 +53,9 @@ func (msg *Message) Encode(w io.Writer) error { } } if msg.Source != nil { + if _, err := w.Write([]byte(":")); err != nil { + return err + } if _, err := w.Write([]byte(msg.Source.String())); err != nil { return err } @@ -63,8 +67,10 @@ func (msg *Message) Encode(w io.Writer) error { if _, err := w.Write([]byte(msg.Command)); err != nil { return err } - if _, err := w.Write([]byte(" ")); err != nil { - return err + if len(msg.Params) > 0 { + if _, err := w.Write([]byte(" ")); err != nil { + return err + } } for idx, v := range msg.Params { if idx != 0 { diff --git a/plugins/auth/auth.go b/plugins/auth/auth.go index b1cb94f..9e7d31a 100644 --- a/plugins/auth/auth.go +++ b/plugins/auth/auth.go @@ -1,9 +1,7 @@ package auth import ( - "context" "encoding/base64" - "tuxpa.in/a/irc/plugins/caps/ircmw" "tuxpa.in/a/irc/pkg/ircv3" ) @@ -13,28 +11,33 @@ type SaslPlain struct { Password string } -func (saslplain *SaslPlain) Middleware(next ircv3.Handler) ircv3.Handler { - return ircv3.HandlerFunc(func(ctx context.Context, w ircv3.MessageWriter, m *ircv3.Message) { - if m.Command == "" { - ircmw.AddPending(ctx, 1) - w.WriteMessage(ircv3.NewMessage("CAP", "REQ", "sasl")) - } - if m.Command == "CAP" && m.Param(0) == "*" && m.Param(1) == "ACK" && m.Param(2) == "sasl" { - w.WriteMessage(ircv3.NewMessage("AUTHENTICATE", "PLAIN")) - } - if m.Command == "AUTHENTICATE" && m.Param(0) == "+" { - w.WriteMessage(ircv3.NewMessage("AUTHENTICATE", base64.StdEncoding.EncodeToString([]byte( +func (saslplain *SaslPlain) Handle(w ircv3.MessageWriter, e *ircv3.Event) bool { + if e.Type == ircv3.EventTypeCONTROL && e.Msg.Command == "/EVENT_ON_SERVE" { + w.WriteMessage(ircv3.NewMessage("CAP", "REQ", "sasl")) + return false + } + m, _ := e.Msg, e.Context() + if m.Command == "CAP" && m.Param(0) == "*" && m.Param(1) == "ACK" && m.Param(2) == "sasl" { + w.WriteMessage(ircv3.NewMessage("AUTHENTICATE", "PLAIN")) + return false + } + if m.Command == "AUTHENTICATE" && m.Param(0) == "+" { + w.WriteMessage(ircv3.NewMessage("AUTHENTICATE", base64.StdEncoding.EncodeToString([]byte( + saslplain.Username+string([]byte{0})+ saslplain.Username+string([]byte{0})+ - saslplain.Username+string([]byte{0})+ - saslplain.Password, - )))) - } - switch m.Command { - case "903", "904", "905", "906", "907": - ircmw.AddPending(ctx, -1) - } - next.Handle(ctx, w, m) - }) + saslplain.Password, + )))) + return false + } + switch m.Command { + case "903", "904", "905", "906", "907": + return true + } + return false +} +func (saslplain *SaslPlain) Middleware(next ircv3.Handler) ircv3.Handler { + // nothing to do + return next } type User struct { @@ -44,24 +47,29 @@ type User struct { Server string } +func (u *User) Handle(w ircv3.MessageWriter, e *ircv3.Event) bool { + if e.Type == ircv3.EventTypeCONTROL && e.Msg.Command == "/EVENT_ON_SERVE" { + w.WriteMessage(ircv3.NewMessage("USER", u.Username, u.Hostname, u.Server, u.Realname)) + } + return true +} + func (u *User) Middleware(next ircv3.Handler) ircv3.Handler { - return ircv3.HandlerFunc(func(ctx context.Context, w ircv3.MessageWriter, m *ircv3.Message) { - if m.Command == "" { - w.WriteMessage(ircv3.NewMessage("USER", u.Username, u.Hostname, u.Server, u.Realname)) - } - next.Handle(ctx, w, m) - }) + // nothing to do + return next } type Nick struct { Nick string } -func (u *Nick) Middleware(next ircv3.Handler) ircv3.Handler { - return ircv3.HandlerFunc(func(ctx context.Context, w ircv3.MessageWriter, m *ircv3.Message) { - if m.Command == "" { - w.WriteMessage(ircv3.NewMessage("NICK", u.Nick)) - } - next.Handle(ctx, w, m) - }) +func (u *Nick) Handle(w ircv3.MessageWriter, e *ircv3.Event) bool { + if e.Type == ircv3.EventTypeCONTROL && e.Msg.Command == "/EVENT_ON_SERVE" { + w.WriteMessage(ircv3.NewMessage("NICK", u.Nick)) + } + return true +} + +func (u *Nick) Middleware(next ircv3.Handler) ircv3.Handler { + return next } diff --git a/plugins/useful/channels.go b/plugins/useful/channels.go index fb248ce..cb9d36f 100644 --- a/plugins/useful/channels.go +++ b/plugins/useful/channels.go @@ -1,7 +1,6 @@ package useful import ( - "context" "strings" "tuxpa.in/a/irc/pkg/ircv3" @@ -12,10 +11,10 @@ type Autojoin struct { } func (u *Autojoin) Middleware(next ircv3.Handler) ircv3.Handler { - return ircv3.HandlerFunc(func(ctx context.Context, w ircv3.MessageWriter, m *ircv3.Message) { - if m.Command == "005" { + return ircv3.HandlerFunc(func(w ircv3.MessageWriter, e *ircv3.Event) { + if e.Msg.Command == "005" { w.WriteMessage(ircv3.NewMessage("JOIN", strings.Join(u.Channels, ","))) } - next.Handle(ctx, w, m) + next.Handle(w, e) }) } diff --git a/plugins/useful/pong.go b/plugins/useful/pong.go index 7385676..455a196 100644 --- a/plugins/useful/pong.go +++ b/plugins/useful/pong.go @@ -1,8 +1,6 @@ package useful import ( - "context" - "tuxpa.in/a/irc/pkg/ircv3" ) @@ -10,10 +8,10 @@ type Pong struct { } func (u *Pong) Middleware(next ircv3.Handler) ircv3.Handler { - return ircv3.HandlerFunc(func(ctx context.Context, w ircv3.MessageWriter, m *ircv3.Message) { - if m.Command == "PING" { - w.WriteMessage(ircv3.NewMessage("PONG", m.Param(0))) + return ircv3.HandlerFunc(func(w ircv3.MessageWriter, e *ircv3.Event) { + if e.Msg.Command == "PING" { + w.WriteMessage(ircv3.NewMessage("PONG", e.Msg.Param(0))) } - next.Handle(ctx, w, m) + next.Handle(w, e) }) } diff --git a/src/lainbot/lainbot.go b/src/lainbot/lainbot.go new file mode 100644 index 0000000..d477071 --- /dev/null +++ b/src/lainbot/lainbot.go @@ -0,0 +1,51 @@ +package lainbot + +import ( + "strings" + + "tuxpa.in/a/irc/pkg/ircv3" +) + +type Instance struct { + Nuh *ircv3.NUH +} + +type messageWriter struct { + mw ircv3.MessageWriter + nuh *ircv3.NUH +} + +func (messagewriter *messageWriter) WriteMessage(msg *ircv3.Message) error { + if msg.Command == "PRIVMSG" { + msg.SetSource(messagewriter.nuh) + } + return messagewriter.mw.WriteMessage(msg) +} + +func (instance *Instance) Handle(w ircv3.MessageWriter, e *ircv3.Event) { + instance.handle(&messageWriter{mw: w, nuh: instance.Nuh}, e) +} +func (instance *Instance) handle(w ircv3.MessageWriter, e *ircv3.Event) { + m := e.Msg + if m.Command != "PRIVMSG" { + return + } + if m.Param(0) != "#lainmaxxing" { + return + } + + messageContent := m.Param(1) + if strings.HasPrefix(messageContent, "!") { + instance.handle(w, e) + } + return +} +func (instance *Instance) handleCommand(w ircv3.MessageWriter, e *ircv3.Event) { + m := e.Msg + messageContent := m.Param(1) + if messageContent == "!ping" { + w.WriteMessage(ircv3.NewMessage("PRIVMSG", m.Param(0), "pong")) + return + } + return +}