irc/pkg/capability/capabilities.go
2024-05-15 04:28:02 -05:00

110 lines
2.6 KiB
Go

package capability
import (
"context"
"strings"
"tuxpa.in/a/irc/pkg/ircv3"
)
// registrarCtxKey is the private type of the context key under which the
// Registrar is stored. A dedicated named type keeps the key from comparing
// equal to a plain struct{}{} key used by any other package.
type registrarCtxKey struct{}

// registrarKey is the context key used to attach the Registrar to an event's
// context and to look it up again.
var registrarKey registrarCtxKey
// Capability is a single IRCv3 capability that takes part in the negotiation
// driven by a Registrar.
type Capability interface {
	// Handle is offered every event while the capability is still waiting
	// to finish negotiating. Returning true marks the capability as ready,
	// after which the Registrar stops calling Handle for it until the next
	// serve event.
	Handle(w ircv3.MessageWriter, e *ircv3.Event) bool

	// Middleware wraps next with the capability's own handler. The wrapped
	// handler is the one the Registrar uses once negotiation has finished.
	Middleware(next ircv3.Handler) ircv3.Handler
}
// Registrar tracks the capabilities a client wants to use and the progress of
// capability negotiation with the server.
type Registrar struct {
	// enabled holds the capabilities registered through Enable, in
	// registration order.
	enabled []Capability

	// discovered holds the whitespace-separated tokens taken from the
	// server's "CAP * LS" reply. It is cleared on every serve event.
	discovered []string

	// negotiating is the current negotiation stage: 0 before any serve
	// event, 1 after the reset on a serve event, 2 once "CAP LS 302" has
	// been written, 3 once the LS reply has been recorded, and 4 once
	// "CAP END" has been written.
	negotiating int
}
// Enable registers c with the Registrar by appending it to the list of
// enabled capabilities. It currently always returns nil.
func (cs *Registrar) Enable(c Capability) error {
	registered := append(cs.enabled, c)
	cs.enabled = registered
	return nil
}
// NewMiddleware returns a middleware constructor that drives capability
// negotiation for the capabilities registered through Enable.
//
// The handler it produces keeps a small state machine in cs.negotiating:
//
//	1 - a serve control event was seen and the state was reset
//	2 - "CAP LS 302" has been written; waiting for the server's reply
//	3 - the "CAP * LS" reply has been recorded in cs.discovered
//	4 - every waiting capability reported ready and "CAP END" was written
//
// Before stage 4, each event is offered to every capability that has not yet
// reported ready and is then passed straight to next. From stage 4 onwards,
// events go through the chain built from the capabilities' Middleware methods.
//
// The list of enabled capabilities is read when NewMiddleware is called, when
// the returned constructor is applied, and again on every serve event.
func (cs *Registrar) NewMiddleware() func(next ircv3.Handler) ircv3.Handler {
	// waiting is the set of capabilities that have not reported ready yet.
	waiting := make(map[Capability]struct{}, len(cs.enabled))
	for _, c := range cs.enabled {
		waiting[c] = struct{}{}
	}
	return func(next ircv3.Handler) ircv3.Handler {
		// Wrap next with every capability's middleware. Each capability
		// wraps the chain built so far, so none of them is dropped; the
		// capability enabled last ends up outermost.
		chain := next
		for _, c := range cs.enabled {
			chain = c.Middleware(chain)
		}
		return ircv3.HandlerFunc(func(w ircv3.MessageWriter, e *ircv3.Event) {
			// Store the pointer, not a copy: the lookups in this package
			// type-assert *Registrar, and handlers must observe the state
			// changes made further down while handling this same event.
			e = e.WithContext(context.WithValue(e.Context(), registrarKey, cs))

			onServe := e.Type == ircv3.EventTypeCONTROL && e.Msg.Command == "/EVENT_ON_SERVE"

			// Reset signal: start over with every enabled capability waiting.
			if onServe {
				for k := range waiting {
					delete(waiting, k)
				}
				for _, c := range cs.enabled {
					waiting[c] = struct{}{}
				}
				cs.negotiating = 1
				cs.discovered = nil
			}

			// Done negotiating, so run the capability middleware chain.
			if cs.negotiating == 4 {
				chain.Handle(w, e)
				return
			}

			// Record the server's capability list and advance the stage.
			// NOTE(review): a multi-line CAP LS 302 reply (continuation
			// lines carrying "*" as the third parameter) is not handled
			// here; the first reply line ends discovery. Confirm whether
			// the targeted servers can send one.
			if cs.negotiating == 2 && e.Msg.Command == "CAP" && e.Msg.Param(0) == "*" && e.Msg.Param(1) == "LS" {
				cs.discovered = append(cs.discovered, strings.Fields(e.Msg.Param(2))...)
				cs.negotiating = 3
			}

			// Ask the server for its capabilities as soon as serving starts.
			if onServe {
				w.WriteMessage(ircv3.NewMessage("CAP", "LS", "302"))
				cs.negotiating = 2
			}

			// Offer the event to every capability that is still waiting.
			// This happens at every stage before 4, so a capability can
			// act before the LS reply has arrived.
			for c := range waiting {
				if c.Handle(w, e) {
					delete(waiting, c)
				}
			}

			// Not done negotiating yet, so bypass the capability middleware.
			next.Handle(w, e)

			// The LS reply is in and nothing is waiting: end negotiation.
			if cs.negotiating == 3 && len(waiting) == 0 {
				cs.negotiating = 4
				w.WriteMessage(ircv3.NewMessage("CAP", "END"))
			}
		})
	}
}
// NegotiatingState returns the negotiation stage of the Registrar attached to
// ctx under registrarKey, or 0 when ctx carries no Registrar.
//
// The Registrar is accepted both as a *Registrar and as a Registrar value, so
// the lookup works regardless of which of the two forms was stored. A nil
// *Registrar is treated as absent.
func NegotiatingState(ctx context.Context) int {
	switch r := ctx.Value(registrarKey).(type) {
	case *Registrar:
		if r != nil {
			return r.negotiating
		}
	case Registrar:
		return r.negotiating
	}
	return 0
}
// IsDoneNegotiating reports whether the negotiation stage found in ctx is 3
// or higher.
//
// NOTE(review): stage 3 is set once the server's LS reply has been recorded,
// while "CAP END" is only written on entering stage 4 - confirm that 3 is the
// intended threshold.
func IsDoneNegotiating(ctx context.Context) bool {
	if NegotiatingState(ctx) < 3 {
		return false
	}
	return true
}
// HasCapability reports whether c is among the capabilities recorded from the
// server's LS reply for the Registrar attached to ctx. It returns false when
// ctx carries no Registrar.
//
// A recorded token matches when it equals c exactly or when it has the form
// "name=value" and name equals c, so HasCapability(ctx, "sasl") is true for a
// server that advertised "sasl=PLAIN".
//
// The Registrar is accepted both as a *Registrar and as a Registrar value, so
// the lookup works regardless of which of the two forms was stored.
func HasCapability(ctx context.Context, c string) bool {
	var discovered []string
	switch r := ctx.Value(registrarKey).(type) {
	case *Registrar:
		if r == nil {
			return false
		}
		discovered = r.discovered
	case Registrar:
		discovered = r.discovered
	default:
		return false
	}
	for _, token := range discovered {
		if token == c {
			return true
		}
		// Also compare against the bare name of a "name=value" token.
		if i := strings.IndexByte(token, '='); i >= 0 && token[:i] == c {
			return true
		}
	}
	return false
}