Browse Source

Fix URI escaping.

pull/2/head
Justine Alexandra Roberts Tunney 11 years ago
parent
commit
0d894a3199
10 changed files with 212 additions and 175 deletions
  1. +67
    -49
      sip/charsets.go
  2. +11
    -0
      sip/charsets_test.go
  3. +49
    -0
      sip/escape.go
  4. +58
    -0
      sip/escape_test.go
  5. +2
    -3
      sip/params.go
  6. +4
    -0
      sip/quote.go
  7. +1
    -1
      sip/quote_test.go
  8. +0
    -8
      sip/sip.rl
  9. +20
    -16
      sip/uri.go
  10. +0
    -98
      sip/util.go

+ 67
- 49
sip/charsets.go View File

@ -1,35 +1,15 @@
// Charset implementation using four int64 bitmasks
//
// Each charset mask is 32 bytes, which fits into a 64 byte cache line.
// Charset implementation using int64 bitmasks
package sip
var tokencMask [4]uint64
var qdtextcMask [4]uint64
func init() {
charsetAddRange(&tokencMask, 'a', 'z')
charsetAddRange(&tokencMask, 'A', 'Z')
charsetAddRange(&tokencMask, '0', '9')
charsetAdd(&tokencMask, '-')
charsetAdd(&tokencMask, '.')
charsetAdd(&tokencMask, '!')
charsetAdd(&tokencMask, '%')
charsetAdd(&tokencMask, '*')
charsetAdd(&tokencMask, '_')
charsetAdd(&tokencMask, '+')
charsetAdd(&tokencMask, '`')
charsetAdd(&tokencMask, '\'')
charsetAdd(&tokencMask, '~')
charsetAdd(&qdtextcMask, '\r')
charsetAdd(&qdtextcMask, '\n')
charsetAdd(&qdtextcMask, '\t')
charsetAdd(&qdtextcMask, ' ')
charsetAdd(&qdtextcMask, '!')
charsetAddRange(&qdtextcMask, 0x23, 0x5B)
charsetAddRange(&qdtextcMask, 0x5D, 0x7E)
}
var (
tokencMask [2]uint64
qdtextcMask [2]uint64
usercMask [2]uint64
passcMask [2]uint64
paramcMask [2]uint64
headercMask [2]uint64
)
func tokenc(c byte) bool {
return charsetContains(&tokencMask, c)
@ -39,6 +19,22 @@ func qdtextc(c byte) bool {
return charsetContains(&qdtextcMask, c)
}
func userc(c byte) bool {
return charsetContains(&usercMask, c)
}
func passc(c byte) bool {
return charsetContains(&passcMask, c)
}
func paramc(c byte) bool {
return charsetContains(&paramcMask, c)
}
func headerc(c byte) bool {
return charsetContains(&headercMask, c)
}
func qdtextesc(c byte) bool {
return 0x00 <= c && c <= 0x09 ||
0x0B <= c && c <= 0x0C ||
@ -49,31 +45,53 @@ func whitespacec(c byte) bool {
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
}
func charsetContains(mask *[4]uint64, i byte) bool {
return mask[i/64]&(1<<(i%64)) != 0
func init() {
charsetAddAlphaNumeric(&tokencMask)
charsetAdd(&tokencMask, '-', '.', '!', '%', '*', '_', '+', '`', '\'', '~')
charsetAddRange(&qdtextcMask, 0x23, 0x5B)
charsetAddRange(&qdtextcMask, 0x5D, 0x7E)
charsetAdd(&qdtextcMask, '\r', '\n', '\t', ' ', '!')
charsetAddAlphaNumeric(&usercMask)
charsetAddMark(&usercMask)
charsetAdd(&usercMask, '&', '=', '+', '$', ',', ';', '?', '/')
charsetAddAlphaNumeric(&passcMask)
charsetAddMark(&passcMask)
charsetAdd(&passcMask, '&', '=', '+', '$', ',')
charsetAddAlphaNumeric(&paramcMask)
charsetAddMark(&paramcMask)
charsetAdd(&paramcMask, '[', ']', '/', ':', '&', '+', '$')
charsetAddAlphaNumeric(&headercMask)
charsetAddMark(&headercMask)
charsetAdd(&headercMask, '[', ']', '/', '?', ':', '+', '$')
}
func charsetContains(mask *[2]uint64, i byte) bool {
return i < 128 && mask[i/64]&(1<<(i%64)) != 0
}
func charsetAdd(mask *[4]uint64, i byte) {
mask[i/64] |= 1 << (i % 64)
func charsetAdd(mask *[2]uint64, vi ...byte) {
for _, i := range vi {
mask[i/64] |= 1 << (i % 64)
}
}
func charsetAddRange(mask *[4]uint64, a, b byte) {
func charsetAddRange(mask *[2]uint64, a, b byte) {
for i := a; i <= b; i++ {
charsetAdd(mask, i)
}
// var m uint64
// i := a
// j := i / 64
// for i <= b {
// m &= 1 << (i % 64)
// i++
// if i%64 == 0 {
// mask[j] &= m
// j = i / 64
// m = 0
// }
// }
// if m != 0 {
// mask[j] |= m
// }
}
func charsetAddMark(mask *[2]uint64) {
charsetAdd(mask, '-', '_', '.', '!', '~', '*', '\'', '(', ')')
}
func charsetAddAlphaNumeric(mask *[2]uint64) {
charsetAddRange(mask, 'a', 'z')
charsetAddRange(mask, 'A', 'Z')
charsetAddRange(mask, '0', '9')
}

+ 11
- 0
sip/charsets_test.go View File

@ -0,0 +1,11 @@
package sip
import (
"testing"
)
func BenchmarkParamc(b *testing.B) {
for i := 0; i < b.N; i++ {
paramc('a')
}
}

+ 49
- 0
sip/escape.go View File

@ -0,0 +1,49 @@
package sip
// escapeUser escapes a URI user, which can't use quoting.
func escapeUser(s []byte) []byte {
return escape(s, userc)
}
// escapePass escapes a URI password, which can't use quoting.
func escapePass(s []byte) []byte {
return escape(s, passc)
}
// escapeParam escapes a URI parameter, which can't use quoting.
func escapeParam(s []byte) []byte {
return escape(s, paramc)
}
// escapeHeader escapes a URI header, which can't use quoting.
func escapeHeader(s []byte) []byte {
return escape(s, headerc)
}
// escape provides arbitrary URI escaping.
func escape(s []byte, p func(byte) bool) []byte {
hc := 0
for i := 0; i < len(s); i++ {
if !p(s[i]) {
hc++
}
}
if hc == 0 {
return s
}
res := make([]byte, len(s)+2*hc)
j := 0
for i := 0; i < len(s); i++ {
c := s[i]
if p(c) {
res[j] = c
j++
} else {
res[j] = '%'
res[j+1] = hexChars[c>>4]
res[j+2] = hexChars[c%16]
j += 3
}
}
return res
}

+ 58
- 0
sip/escape_test.go View File

@ -0,0 +1,58 @@
package sip
import (
"testing"
)
type escapeTest struct {
name string
in string
out string
p func([]byte) []byte
}
var escapeTests = []escapeTest{
escapeTest{
name: "Param Normal",
in: "hello",
out: "hello",
p: escapeParam,
},
escapeTest{
name: "User Normal",
in: "hello",
out: "hello",
p: escapeUser,
},
escapeTest{
name: "Param Spacing",
in: "hello there",
out: "hello%20there",
p: escapeParam,
},
escapeTest{
name: "User Spacing",
in: "hello there",
out: "hello%20there",
p: escapeUser,
},
}
func TestEscape(t *testing.T) {
for _, test := range escapeTests {
out := string(test.p([]byte(test.in)))
if test.out != out {
t.Errorf("%s: %s != %s", test.name, test.out, out)
}
}
}
func BenchmarkEscapeParam(b *testing.B) {
for i := 0; i < b.N; i++ {
escapeParam([]byte("hello there"))
}
}

+ 2
- 3
sip/params.go View File

@ -2,7 +2,6 @@ package sip
import (
"bytes"
"github.com/jart/gosip/util"
"sort"
)
@ -27,11 +26,11 @@ func (params Params) Append(b *bytes.Buffer) {
sort.Strings(keys)
for _, k := range keys {
b.WriteByte(';')
b.WriteString(util.URLEscape(k, false))
b.Write(escapeParam([]byte(k)))
v := params[k]
if v != "" {
b.WriteByte('=')
b.WriteString(util.URLEscape(v, false))
b.Write(escapeParam([]byte(v)))
}
}
}


+ 4
- 0
sip/quote.go View File

@ -4,6 +4,10 @@ import (
"bytes"
)
const (
hexChars = "0123456789abcdef"
)
// tokencify removes all characters that aren't tokenc.
func tokencify(s []byte) []byte {
t := make([]byte, len(s))


+ 1
- 1
sip/quote_test.go View File

@ -59,7 +59,7 @@ func TestQuote(t *testing.T) {
for _, test := range quoteTests {
out := string(quote([]byte(test.in)))
if test.out != out {
t.Error(test.name, test.out, "!=", out)
t.Errorf("%s: %s != %s", test.name, test.out, out)
}
}
}

+ 0
- 8
sip/sip.rl View File

@ -80,10 +80,6 @@ action space {
amt++
}
action collapse {
amt = appendCollapse(buf, amt, fc)
}
action hexHi {
hex = unhex(fc) * 16
}
@ -94,10 +90,6 @@ action hexLo {
amt++
}
action lower {
amt = appendLower(buf, amt, fc)
}
action Method {
msg.Method = string(data[mark:p])
}


+ 20
- 16
sip/uri.go View File

@ -76,27 +76,31 @@ func (uri *URI) String() string {
func (uri *URI) Append(b *bytes.Buffer) {
if uri.Scheme == "" {
uri.Scheme = "sip"
b.WriteString("sip:")
} else {
b.WriteString(uri.Scheme)
b.WriteByte(':')
}
b.WriteString(uri.Scheme)
b.WriteString(":")
if uri.User != "" {
if uri.Pass != "" {
b.WriteString(util.URLEscape(uri.User, false))
b.WriteString(":")
b.WriteString(util.URLEscape(uri.Pass, false))
b.Write(escapeUser([]byte(uri.User)))
b.WriteByte(':')
b.Write(escapePass([]byte(uri.Pass)))
} else {
b.WriteString(util.URLEscape(uri.User, false))
b.Write(escapeUser([]byte(uri.User)))
}
b.WriteString("@")
b.WriteByte('@')
}
if util.IsIPv6(uri.Host) {
b.WriteString("[" + util.URLEscape(uri.Host, false) + "]")
b.WriteByte('[')
b.WriteString(uri.Host)
b.WriteByte(']')
} else {
b.WriteString(util.URLEscape(uri.Host, false))
b.WriteString(uri.Host)
}
if uri.Port > 0 {
b.WriteString(":" + portstr((uri.Port)))
b.WriteByte(':')
b.WriteString(portstr(uri.Port))
}
uri.Params.Append(b)
uri.Headers.Append(b)
@ -143,16 +147,16 @@ func (headers URIHeaders) Append(b *bytes.Buffer) {
first := true
for _, k := range keys {
if first {
b.WriteString("?")
b.WriteByte('?')
first = false
} else {
b.WriteString("&")
b.WriteByte('&')
}
b.WriteString(util.URLEscape(k, false))
b.Write(escapeHeader([]byte(k)))
v := headers[k]
if v != "" {
b.WriteString("=")
b.WriteString(util.URLEscape(v, false))
b.WriteByte('=')
b.Write(escapeHeader([]byte(v)))
}
}
}


+ 0
- 98
sip/util.go View File

@ -1,9 +1,7 @@
package sip
import (
"github.com/jart/gosip/util"
"strconv"
"strings"
"time"
)
@ -22,79 +20,6 @@ func portstr(port uint16) string {
return strconv.FormatInt(int64(port), 10)
}
func extractHostPort(s string) (s2, host string, port uint16, err error) {
if s == "" {
err = URIMissingHost
} else {
if s[0] == '[' { // quoted/ipv6: sip:[dead:beef::666]:5060
n := strings.Index(s, "]")
if n < 0 {
err = URIMissingHost
}
host, s = s[1:n], s[n+1:]
if s != "" && s[0] == ':' { // we has a port too
s = s[1:]
s, port, err = extractPort(s)
}
} else { // non-quoted host: sip:1.2.3.4:5060
switch n := strings.IndexAny(s, delims); {
case n < 0:
host, s = s, ""
case s[n] == ':': // host:port
host, s = s[0:n], s[n+1:]
s, port, err = extractPort(s)
default:
host, s = s[0:n], s[n:]
}
}
}
return s, host, port, err
}
func parseParams(s string) (res Params) {
items := strings.Split(s, ";")
if items == nil || len(items) == 0 || items[0] == "" {
return
}
res = make(Params, len(items))
for _, item := range items {
if item == "" {
continue
}
n := strings.Index(item, "=")
var k, v string
if n > 0 {
k, v = item[0:n], item[n+1:]
} else {
k, v = item, ""
}
k, kerr := util.URLUnescape(k, false)
v, verr := util.URLUnescape(v, false)
if kerr != nil || verr != nil {
continue
}
res[k] = v
}
return res
}
func extractPort(s string) (s2 string, port uint16, err error) {
if n := strings.IndexAny(s, delims); n > 0 {
port, err = parsePort(s[0:n])
s = s[n:]
} else {
port, err = parsePort(s)
s = ""
}
return s, port, err
}
func parsePort(s string) (port uint16, err error) {
i, err := strconv.ParseUint(s, 10, 16)
port = uint16(i)
return
}
func unhex(b byte) byte {
switch {
case '0' <= b && b <= '9':
@ -106,26 +31,3 @@ func unhex(b byte) byte {
}
return 0
}
func appendCollapse(buf []byte, amt int, fc byte) int {
switch fc {
case ' ', '\t', '\r', '\n':
if amt == 0 || buf[amt-1] != ' ' {
buf[amt] = ' '
amt++
}
default:
buf[amt] = fc
amt++
}
return amt
}
func appendLower(buf []byte, amt int, fc byte) int {
if 'A' <= fc && fc <= 'Z' {
buf[amt] = fc + 0x20
} else {
buf[amt] = fc
}
return amt + 1
}

Loading…
Cancel
Save