irc/pkg/ircdecoder/decoder.go
2024-05-14 00:51:42 -05:00

198 lines
3.5 KiB
Go

package ircdecoder
import (
"errors"
"fmt"
"io"
"strings"
"unicode/utf8"
"tuxpa.in/a/irc/pkg/ircv3"
)
// Decoder parses IRC messages from a byte stream. It has no fields, so the
// zero value is ready to use.
type Decoder struct{}
// readByte returns the next byte from r.
//
// When r implements io.ByteReader its ReadByte method is used directly.
// Otherwise exactly one byte is pulled through io.ReadFull, and any error
// from that read is returned together with a zero byte.
func (d *Decoder) readByte(r io.Reader) (byte, error) {
	switch src := r.(type) {
	case io.ByteReader:
		return src.ReadByte()
	default:
		var cell [1]byte
		if _, err := io.ReadFull(r, cell[:]); err != nil {
			return 0, err
		}
		return cell[0], nil
	}
}
// Decode reads a single message from r and stores the parsed fields in msg.
// It is the exported entry point and simply forwards to decode, returning
// whatever error that call reports.
func (d *Decoder) Decode(r io.Reader, msg *ircv3.Message) error {
	err := d.decode(r, msg)
	return err
}
// decodeTags reads the tag section of a message from r into msg.Tags.
//
// The caller must already have consumed the leading '@'. Bytes are read until
// the first space, which ends the tag section. Tags are separated by ';' and
// each tag is either "key" or "key=value". msg.Tags is allocated when nil.
//
// Every value is passed through ircv3.UnescapeTagValue before it is stored,
// and every key must be valid UTF-8, otherwise an error is returned. Only the
// first '=' of a tag separates key from value; later ones belong to the value.
// Empty segments (for example a trailing ';') are ignored. Any read error,
// including io.EOF before the terminating space, is returned as is.
func (d *Decoder) decodeTags(r io.Reader, msg *ircv3.Message) error {
	// we assume we have already read the @
	if msg.Tags == nil {
		msg.Tags = make(ircv3.Tags)
	}
	var key, val strings.Builder
	inValue := false

	// store commits the tag collected so far and clears the per-tag state so
	// the next tag starts from empty buffers.
	store := func() error {
		k, v := key.String(), val.String()
		hadValue := inValue
		key.Reset()
		val.Reset()
		inValue = false
		if k == "" && !hadValue {
			// nothing was collected between two delimiters
			return nil
		}
		// TODO: technically we should check the validity of key bytes as they arrive
		// <key_name> ::= <non-empty sequence of ascii letters, digits, hyphens ('-')>
		if !utf8.ValidString(k) {
			return fmt.Errorf("non utf-8 tag key")
		}
		msg.Tags.Set(k, ircv3.UnescapeTagValue(v))
		return nil
	}

	for {
		b, err := d.readByte(r)
		if err != nil {
			return err
		}
		switch {
		case b == 0x20:
			// a space terminates the whole tag section
			return store()
		case b == ';':
			if err := store(); err != nil {
				return err
			}
		case b == '=' && !inValue:
			inValue = true
		case inValue:
			val.WriteByte(b)
		default:
			key.WriteByte(b)
		}
	}
}
// decodeSource reads the message source from r and stores it in msg.Source.
//
// The caller must already have consumed the leading ':'. Bytes are collected
// up to, but not including, the first space; the collected text is handed to
// ircv3.ParseNUH. Read errors and parse errors are returned unchanged, and
// msg.Source is only assigned when parsing succeeds.
func (d *Decoder) decodeSource(r io.Reader, msg *ircv3.Message) error {
	var raw strings.Builder
	for {
		c, readErr := d.readByte(r)
		if readErr != nil {
			return readErr
		}
		if c == ' ' {
			// the space ends the source; parse what was gathered
			source, parseErr := ircv3.ParseNUH(raw.String())
			if parseErr != nil {
				return parseErr
			}
			msg.Source = &source
			return nil
		}
		raw.WriteByte(c)
	}
}
// decode reads one IRC message from r into msg.
//
// The expected layout is
//
//	['@' tags SPACE] [':' source SPACE] command {SPACE param} [SPACE ':' trailing] CRLF
//
// Tags and source are delegated to decodeTags and decodeSource. An error while
// reading the first byte, the tags, the source, or the first byte of the
// command (io.EOF included) is returned to the caller. Everything after that
// first command byte is read through a 511-byte limit; reaching the limit or
// the end of the stream finishes the message the same way CRLF does, and only
// errors other than io.EOF are returned.
//
// A CR that is not followed by LF is kept as ordinary data; a CR left pending
// at the end of the stream is discarded. Empty middle parameters (caused by
// repeated or trailing spaces) are not emitted, while an explicitly empty
// trailing parameter (" :") is.
func (d *Decoder) decode(r io.Reader, msg *ircv3.Message) error {
	b, err := d.readByte(r)
	if err != nil {
		return err
	}
	if b == '@' {
		if err := d.decodeTags(r, msg); err != nil {
			return err
		}
		if b, err = d.readByte(r); err != nil {
			return err
		}
	}
	// a source may follow the tags or start the message
	if b == ':' {
		if err := d.decodeSource(r, msg); err != nil {
			return err
		}
		if b, err = d.readByte(r); err != nil {
			return err
		}
	}

	// at this point, no matter what, b holds the first byte of the command
	cb := new(strings.Builder)
	// pendingCR records a CR whose meaning depends on the next byte: followed
	// by LF it ends the message, otherwise it is data.
	pendingCR := b == '\r'
	if !pendingCR {
		cb.WriteByte(b)
	}

	// add a limit reader for the irc size limit
	r = io.LimitReader(r, 511)

	// read the command: up to the first space, CRLF, or end of input
	for {
		b, err = d.readByte(r)
		if err != nil {
			if errors.Is(err, io.EOF) {
				msg.Command = cb.String()
				return nil
			}
			return err
		}
		if pendingCR {
			pendingCR = false
			if b == '\n' {
				// command without parameters
				msg.Command = cb.String()
				return nil
			}
			cb.WriteByte('\r')
		}
		if b == '\r' {
			pendingCR = true
			continue
		}
		if b == 0x20 {
			break
		}
		cb.WriteByte(b)
	}
	msg.Command = cb.String()
	cb.Reset()

	// now read the params
	trailing := false
	// flush emits the parameter collected so far. An empty buffer is only a
	// parameter when it is the trailing one.
	flush := func() {
		if trailing || cb.Len() > 0 {
			msg.Params = append(msg.Params, cb.String())
		}
		cb.Reset()
	}
	for {
		b, err = d.readByte(r)
		if err != nil {
			if errors.Is(err, io.EOF) {
				flush()
				return nil
			}
			return err
		}
		if pendingCR {
			// always clear the flag so one lone CR does not affect later bytes
			pendingCR = false
			if b == '\n' {
				flush()
				return nil
			}
			cb.WriteByte('\r')
		}
		switch {
		case b == '\r':
			pendingCR = true
		case !trailing && b == ':' && cb.Len() == 0:
			// only the first ':' at the start of a parameter is a marker;
			// later colons are part of the trailing text
			trailing = true
		case !trailing && b == 0x20:
			flush()
		default:
			cb.WriteByte(b)
		}
	}
}