From 0d894a3199ef32d5f262a1cb5bc61e92cb605b1d Mon Sep 17 00:00:00 2001 From: Justine Alexandra Roberts Tunney Date: Sat, 11 Apr 2015 01:02:35 -0400 Subject: [PATCH] Fix URI escaping. --- sip/charsets.go | 116 +++++++++++++++++++++++++------------------ sip/charsets_test.go | 11 ++++ sip/escape.go | 49 ++++++++++++++++++ sip/escape_test.go | 58 ++++++++++++++++++++++ sip/params.go | 5 +- sip/quote.go | 4 ++ sip/quote_test.go | 2 +- sip/sip.rl | 8 --- sip/uri.go | 36 ++++++++------ sip/util.go | 98 ------------------------------------ 10 files changed, 212 insertions(+), 175 deletions(-) create mode 100644 sip/charsets_test.go create mode 100644 sip/escape.go create mode 100644 sip/escape_test.go diff --git a/sip/charsets.go b/sip/charsets.go index b9e2b58..6249488 100644 --- a/sip/charsets.go +++ b/sip/charsets.go @@ -1,35 +1,15 @@ -// Charset implementation using four int64 bitmasks -// -// Each charset mask is 32 bytes, which fits into a 64 byte cache line. +// Charset implementation using uint64 bitmasks package sip -var tokencMask [4]uint64 -var qdtextcMask [4]uint64 - -func init() { - charsetAddRange(&tokencMask, 'a', 'z') - charsetAddRange(&tokencMask, 'A', 'Z') - charsetAddRange(&tokencMask, '0', '9') - charsetAdd(&tokencMask, '-') - charsetAdd(&tokencMask, '.') - charsetAdd(&tokencMask, '!') - charsetAdd(&tokencMask, '%') - charsetAdd(&tokencMask, '*') - charsetAdd(&tokencMask, '_') - charsetAdd(&tokencMask, '+') - charsetAdd(&tokencMask, '`') - charsetAdd(&tokencMask, '\'') - charsetAdd(&tokencMask, '~') - - charsetAdd(&qdtextcMask, '\r') - charsetAdd(&qdtextcMask, '\n') - charsetAdd(&qdtextcMask, '\t') - charsetAdd(&qdtextcMask, ' ') - charsetAdd(&qdtextcMask, '!') - charsetAddRange(&qdtextcMask, 0x23, 0x5B) - charsetAddRange(&qdtextcMask, 0x5D, 0x7E) -} +var ( + tokencMask [2]uint64 + qdtextcMask [2]uint64 + usercMask [2]uint64 + passcMask [2]uint64 + paramcMask [2]uint64 + headercMask [2]uint64 +) func tokenc(c byte) bool { return 
charsetContains(&tokencMask, c) @@ -39,6 +19,22 @@ func qdtextc(c byte) bool { return charsetContains(&qdtextcMask, c) } +func userc(c byte) bool { + return charsetContains(&usercMask, c) +} + +func passc(c byte) bool { + return charsetContains(&passcMask, c) +} + +func paramc(c byte) bool { + return charsetContains(&paramcMask, c) +} + +func headerc(c byte) bool { + return charsetContains(&headercMask, c) +} + func qdtextesc(c byte) bool { return 0x00 <= c && c <= 0x09 || 0x0B <= c && c <= 0x0C || @@ -49,31 +45,53 @@ func whitespacec(c byte) bool { return c == ' ' || c == '\t' || c == '\r' || c == '\n' } -func charsetContains(mask *[4]uint64, i byte) bool { - return mask[i/64]&(1<<(i%64)) != 0 +func init() { + charsetAddAlphaNumeric(&tokencMask) + charsetAdd(&tokencMask, '-', '.', '!', '%', '*', '_', '+', '`', '\'', '~') + + charsetAddRange(&qdtextcMask, 0x23, 0x5B) + charsetAddRange(&qdtextcMask, 0x5D, 0x7E) + charsetAdd(&qdtextcMask, '\r', '\n', '\t', ' ', '!') + + charsetAddAlphaNumeric(&usercMask) + charsetAddMark(&usercMask) + charsetAdd(&usercMask, '&', '=', '+', '$', ',', ';', '?', '/') + + charsetAddAlphaNumeric(&passcMask) + charsetAddMark(&passcMask) + charsetAdd(&passcMask, '&', '=', '+', '$', ',') + + charsetAddAlphaNumeric(&paramcMask) + charsetAddMark(&paramcMask) + charsetAdd(&paramcMask, '[', ']', '/', ':', '&', '+', '$') + + charsetAddAlphaNumeric(&headercMask) + charsetAddMark(&headercMask) + charsetAdd(&headercMask, '[', ']', '/', '?', ':', '+', '$') +} + +func charsetContains(mask *[2]uint64, i byte) bool { + return i < 128 && mask[i/64]&(1<<(i%64)) != 0 } -func charsetAdd(mask *[4]uint64, i byte) { - mask[i/64] |= 1 << (i % 64) +func charsetAdd(mask *[2]uint64, vi ...byte) { + for _, i := range vi { + mask[i/64] |= 1 << (i % 64) + } } -func charsetAddRange(mask *[4]uint64, a, b byte) { +func charsetAddRange(mask *[2]uint64, a, b byte) { for i := a; i <= b; i++ { charsetAdd(mask, i) } - // var m uint64 - // i := a - // j := i / 64 - // for i <= b { - // m &= 1 
<< (i % 64) - // i++ - // if i%64 == 0 { - // mask[j] &= m - // j = i / 64 - // m = 0 - // } - // } - // if m != 0 { - // mask[j] |= m - // } +} + +func charsetAddMark(mask *[2]uint64) { + charsetAdd(mask, '-', '_', '.', '!', '~', '*', '\'', '(', ')') +} + +func charsetAddAlphaNumeric(mask *[2]uint64) { + charsetAddRange(mask, 'a', 'z') + charsetAddRange(mask, 'A', 'Z') + charsetAddRange(mask, '0', '9') } diff --git a/sip/charsets_test.go b/sip/charsets_test.go new file mode 100644 index 0000000..d0c1f15 --- /dev/null +++ b/sip/charsets_test.go @@ -0,0 +1,11 @@ +package sip + +import ( + "testing" +) + +func BenchmarkParamc(b *testing.B) { + for i := 0; i < b.N; i++ { + paramc('a') + } +} diff --git a/sip/escape.go b/sip/escape.go new file mode 100644 index 0000000..1e44471 --- /dev/null +++ b/sip/escape.go @@ -0,0 +1,49 @@ +package sip + +// escapeUser escapes a URI user, which can't use quoting. +func escapeUser(s []byte) []byte { + return escape(s, userc) +} + +// escapePass escapes a URI password, which can't use quoting. +func escapePass(s []byte) []byte { + return escape(s, passc) +} + +// escapeParam escapes a URI parameter, which can't use quoting. +func escapeParam(s []byte) []byte { + return escape(s, paramc) +} + +// escapeHeader escapes a URI header, which can't use quoting. +func escapeHeader(s []byte) []byte { + return escape(s, headerc) +} + +// escape provides arbitrary URI escaping. 
+func escape(s []byte, p func(byte) bool) []byte { + hc := 0 + for i := 0; i < len(s); i++ { + if !p(s[i]) { + hc++ + } + } + if hc == 0 { + return s + } + res := make([]byte, len(s)+2*hc) + j := 0 + for i := 0; i < len(s); i++ { + c := s[i] + if p(c) { + res[j] = c + j++ + } else { + res[j] = '%' + res[j+1] = hexChars[c>>4] + res[j+2] = hexChars[c%16] + j += 3 + } + } + return res +} diff --git a/sip/escape_test.go b/sip/escape_test.go new file mode 100644 index 0000000..1d0c933 --- /dev/null +++ b/sip/escape_test.go @@ -0,0 +1,58 @@ +package sip + +import ( + "testing" +) + +type escapeTest struct { + name string + in string + out string + p func([]byte) []byte +} + +var escapeTests = []escapeTest{ + + escapeTest{ + name: "Param Normal", + in: "hello", + out: "hello", + p: escapeParam, + }, + + escapeTest{ + name: "User Normal", + in: "hello", + out: "hello", + p: escapeUser, + }, + + escapeTest{ + name: "Param Spacing", + in: "hello there", + out: "hello%20there", + p: escapeParam, + }, + + escapeTest{ + name: "User Spacing", + in: "hello there", + out: "hello%20there", + p: escapeUser, + }, +} + +func TestEscape(t *testing.T) { + for _, test := range escapeTests { + out := string(test.p([]byte(test.in))) + if test.out != out { + t.Errorf("%s: %s != %s", test.name, test.out, out) + } + } +} + +func BenchmarkEscapeParam(b *testing.B) { + for i := 0; i < b.N; i++ { + escapeParam([]byte("hello there")) + } +} diff --git a/sip/params.go b/sip/params.go index b24f484..83c1c60 100644 --- a/sip/params.go +++ b/sip/params.go @@ -2,7 +2,6 @@ package sip import ( "bytes" - "github.com/jart/gosip/util" "sort" ) @@ -27,11 +26,11 @@ func (params Params) Append(b *bytes.Buffer) { sort.Strings(keys) for _, k := range keys { b.WriteByte(';') - b.WriteString(util.URLEscape(k, false)) + b.Write(escapeParam([]byte(k))) v := params[k] if v != "" { b.WriteByte('=') - b.WriteString(util.URLEscape(v, false)) + b.Write(escapeParam([]byte(v))) } } } diff --git a/sip/quote.go 
b/sip/quote.go index f3cabf9..3ca318b 100644 --- a/sip/quote.go +++ b/sip/quote.go @@ -4,6 +4,10 @@ import ( "bytes" ) +const ( + hexChars = "0123456789abcdef" +) + // tokencify removes all characters that aren't tokenc. func tokencify(s []byte) []byte { t := make([]byte, len(s)) diff --git a/sip/quote_test.go b/sip/quote_test.go index bcc2dfc..85f3981 100644 --- a/sip/quote_test.go +++ b/sip/quote_test.go @@ -59,7 +59,7 @@ func TestQuote(t *testing.T) { for _, test := range quoteTests { out := string(quote([]byte(test.in))) if test.out != out { - t.Error(test.name, test.out, "!=", out) + t.Errorf("%s: %s != %s", test.name, test.out, out) } } } diff --git a/sip/sip.rl b/sip/sip.rl index 9ca6b1e..325aafb 100644 --- a/sip/sip.rl +++ b/sip/sip.rl @@ -80,10 +80,6 @@ action space { amt++ } -action collapse { - amt = appendCollapse(buf, amt, fc) -} - action hexHi { hex = unhex(fc) * 16 } @@ -94,10 +90,6 @@ action hexLo { amt++ } -action lower { - amt = appendLower(buf, amt, fc) -} - action Method { msg.Method = string(data[mark:p]) } diff --git a/sip/uri.go b/sip/uri.go index baa0fc2..f0b58bc 100755 --- a/sip/uri.go +++ b/sip/uri.go @@ -76,27 +76,31 @@ func (uri *URI) String() string { func (uri *URI) Append(b *bytes.Buffer) { if uri.Scheme == "" { - uri.Scheme = "sip" + b.WriteString("sip:") + } else { + b.WriteString(uri.Scheme) + b.WriteByte(':') } - b.WriteString(uri.Scheme) - b.WriteString(":") if uri.User != "" { if uri.Pass != "" { - b.WriteString(util.URLEscape(uri.User, false)) - b.WriteString(":") - b.WriteString(util.URLEscape(uri.Pass, false)) + b.Write(escapeUser([]byte(uri.User))) + b.WriteByte(':') + b.Write(escapePass([]byte(uri.Pass))) } else { - b.WriteString(util.URLEscape(uri.User, false)) + b.Write(escapeUser([]byte(uri.User))) } - b.WriteString("@") + b.WriteByte('@') } if util.IsIPv6(uri.Host) { - b.WriteString("[" + util.URLEscape(uri.Host, false) + "]") + b.WriteByte('[') + b.WriteString(uri.Host) + b.WriteByte(']') } else { - 
b.WriteString(util.URLEscape(uri.Host, false)) + b.WriteString(uri.Host) } if uri.Port > 0 { - b.WriteString(":" + portstr((uri.Port))) + b.WriteByte(':') + b.WriteString(portstr(uri.Port)) } uri.Params.Append(b) uri.Headers.Append(b) @@ -143,16 +147,16 @@ func (headers URIHeaders) Append(b *bytes.Buffer) { first := true for _, k := range keys { if first { - b.WriteString("?") + b.WriteByte('?') first = false } else { - b.WriteString("&") + b.WriteByte('&') } - b.WriteString(util.URLEscape(k, false)) + b.Write(escapeHeader([]byte(k))) v := headers[k] if v != "" { - b.WriteString("=") - b.WriteString(util.URLEscape(v, false)) + b.WriteByte('=') + b.Write(escapeHeader([]byte(v))) } } } diff --git a/sip/util.go b/sip/util.go index dc1e235..5a969c1 100644 --- a/sip/util.go +++ b/sip/util.go @@ -1,9 +1,7 @@ package sip import ( - "github.com/jart/gosip/util" "strconv" - "strings" "time" ) @@ -22,79 +20,6 @@ func portstr(port uint16) string { return strconv.FormatInt(int64(port), 10) } -func extractHostPort(s string) (s2, host string, port uint16, err error) { - if s == "" { - err = URIMissingHost - } else { - if s[0] == '[' { // quoted/ipv6: sip:[dead:beef::666]:5060 - n := strings.Index(s, "]") - if n < 0 { - err = URIMissingHost - } - host, s = s[1:n], s[n+1:] - if s != "" && s[0] == ':' { // we has a port too - s = s[1:] - s, port, err = extractPort(s) - } - } else { // non-quoted host: sip:1.2.3.4:5060 - switch n := strings.IndexAny(s, delims); { - case n < 0: - host, s = s, "" - case s[n] == ':': // host:port - host, s = s[0:n], s[n+1:] - s, port, err = extractPort(s) - default: - host, s = s[0:n], s[n:] - } - } - } - return s, host, port, err -} - -func parseParams(s string) (res Params) { - items := strings.Split(s, ";") - if items == nil || len(items) == 0 || items[0] == "" { - return - } - res = make(Params, len(items)) - for _, item := range items { - if item == "" { - continue - } - n := strings.Index(item, "=") - var k, v string - if n > 0 { - k, v = 
item[0:n], item[n+1:] - } else { - k, v = item, "" - } - k, kerr := util.URLUnescape(k, false) - v, verr := util.URLUnescape(v, false) - if kerr != nil || verr != nil { - continue - } - res[k] = v - } - return res -} - -func extractPort(s string) (s2 string, port uint16, err error) { - if n := strings.IndexAny(s, delims); n > 0 { - port, err = parsePort(s[0:n]) - s = s[n:] - } else { - port, err = parsePort(s) - s = "" - } - return s, port, err -} - -func parsePort(s string) (port uint16, err error) { - i, err := strconv.ParseUint(s, 10, 16) - port = uint16(i) - return -} - func unhex(b byte) byte { switch { case '0' <= b && b <= '9': @@ -106,26 +31,3 @@ func unhex(b byte) byte { } return 0 } - -func appendCollapse(buf []byte, amt int, fc byte) int { - switch fc { - case ' ', '\t', '\r', '\n': - if amt == 0 || buf[amt-1] != ' ' { - buf[amt] = ' ' - amt++ - } - default: - buf[amt] = fc - amt++ - } - return amt -} - -func appendLower(buf []byte, amt int, fc byte) int { - if 'A' <= fc && fc <= 'Z' { - buf[amt] = fc + 0x20 - } else { - buf[amt] = fc - } - return amt + 1 -}