diff --git a/sip/receiver.go b/sip/receiver.go index 6644ae9..367a976 100644 --- a/sip/receiver.go +++ b/sip/receiver.go @@ -8,7 +8,7 @@ import ( "time" ) -func ReceiveMessages(contact *Addr, sock *net.UDPConn, c chan *Msg, e chan error) { +func ReceiveMessages(contact *Addr, sock *net.UDPConn, c chan<- *Msg, e chan<- error) { buf := make([]byte, 2048) for { amt, addr, err := sock.ReadFromUDP(buf) diff --git a/sip/route.go b/sip/route.go index bd8ed60..e75d7c2 100644 --- a/sip/route.go +++ b/sip/route.go @@ -2,12 +2,16 @@ package sip import ( "errors" + "log" "net" ) -func RouteMessage(via *Via, contact *Addr, old *Msg) (msg *Msg, dest string, err error) { - var host string - var port uint16 +type AddressRoute struct { + Address string + Next *AddressRoute +} + +func RouteMessage(via *Via, contact *Addr, old *Msg) (msg *Msg, host string, port uint16, err error) { msg = new(Msg) *msg = *old // Start off with a shallow copy. if msg.Contact == nil { @@ -17,10 +21,9 @@ func RouteMessage(via *Via, contact *Addr, old *Msg) (msg *Msg, dest string, err if via.CompareHostPort(msg.Via) { msg.Via = msg.Via.Next } + host, port = msg.Via.Host, msg.Via.Port if received, ok := msg.Via.Params["received"]; ok { - return msg, received, nil - } else { - host, port = msg.Via.Host, msg.Via.Port + host = received } } else { if contact.CompareHostPort(msg.Route) { @@ -28,23 +31,50 @@ func RouteMessage(via *Via, contact *Addr, old *Msg) (msg *Msg, dest string, err } if msg.Route != nil { if msg.Method == "REGISTER" { - return nil, "", errors.New("Don't route REGISTER requests") + return nil, "", 0, errors.New("Don't route REGISTER requests") } if msg.Route.Uri.Params.Has("lr") { // RFC3261 16.12.1.1 Basic SIP Trapezoid - host, port = msg.Route.Uri.Host, msg.Route.Uri.GetPort() + host, port = msg.Route.Uri.Host, msg.Route.Uri.Port } else { // RFC3261 16.12.1.2: Traversing a Strict-Routing Proxy msg.Route = old.Route.Copy() msg.Route.Last().Next = &Addr{Uri: msg.Request} msg.Request = msg.Route.Uri msg.Route = msg.Route.Next - host, port = msg.Request.Host, msg.Request.GetPort() + host, port = msg.Request.Host, msg.Request.Port } } else { - host, port = msg.Request.Host, msg.Request.GetPort() + host, port = msg.Request.Host, msg.Request.Port } } - dest = net.JoinHostPort(host, portstr(port)) return } + +func RouteAddress(host string, port uint16) (routes *AddressRoute, err error) { + if net.ParseIP(host) != nil { + if port == 0 { + port = 5060 + } + return &AddressRoute{Address: net.JoinHostPort(host, portstr(port))}, nil + } + if port == 0 { + _, srvs, err := net.LookupSRV("sip", "udp", host) + if err == nil && len(srvs) > 0 { + for i := len(srvs) - 1; i >= 0; i-- { + routes = &AddressRoute{ + Address: net.JoinHostPort(srvs[i].Target, portstr(srvs[i].Port)), + Next: routes, + } + } + return routes, nil + } + log.Println("net.LookupSRV(sip, udp, %s) failed: %s", err) + port = 5060 + } + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(host, portstr(port))) + if err != nil { + return nil, err + } + return &AddressRoute{Address: addr.String()}, nil +} diff --git a/sip/trace.go b/sip/trace.go index 69b9828..94349e8 100644 --- a/sip/trace.go +++ b/sip/trace.go @@ -1,12 +1,18 @@ package sip import ( + "flag" "log" "net" "strings" "time" ) +var ( + tracing = flag.Bool("tracing", true, "Enable SIP message tracing") + timestampTagging = flag.Bool("timestampTagging", false, "Add microsecond timestamps to Via tags") +) + func trace(dir, pkt string, addr net.Addr, t time.Time) { size := len(pkt) bar := strings.Repeat("-", 72) diff --git a/sip/transport.go b/sip/transport.go index ac5760c..bcf3f75 100755 --- a/sip/transport.go +++ b/sip/transport.go @@ -5,23 +5,16 @@ package sip import ( "bytes" - "errors" - "flag" "net" "time" ) -var ( - tracing = flag.Bool("tracing", true, "Enable SIP message tracing") - timestampTagging = flag.Bool("timestampTagging", false, "Add microsecond timestamps to Via tags") -) - // Transport sends and receives SIP messages over UDP with stateless routing. type Transport struct { // Channel to which received SIP messages and errors are published. - C chan *Msg - E chan error + C <-chan *Msg + E <-chan error // Underlying UDP socket. Sock *net.UDPConn @@ -45,19 +38,21 @@ type Transport struct { // canonical address. func NewTransport(contact *Addr) (tp *Transport, err error) { saddr := net.JoinHostPort(contact.Uri.Host, portstr(contact.Uri.Port)) - c, err := net.ListenPacket("udp", saddr) + conn, err := net.ListenPacket("udp", saddr) if err != nil { return nil, err } - sock := c.(*net.UDPConn) - addr := c.LocalAddr().(*net.UDPAddr) + sock := conn.(*net.UDPConn) + addr := conn.LocalAddr().(*net.UDPAddr) contact = contact.Copy() contact.Next = nil contact.Uri.Port = uint16(addr.Port) contact.Uri.Params["transport"] = "udp" + c := make(chan *Msg, 32) + e := make(chan error, 1) tp = &Transport{ - C: make(chan *Msg, 32), - E: make(chan error, 1), + C: c, + E: e, Sock: sock, Contact: contact, Via: &Via{ @@ -65,17 +60,17 @@ func NewTransport(contact *Addr) (tp *Transport, err error) { Port: uint16(addr.Port), }, } - go ReceiveMessages(contact, sock, tp.C, tp.E) + go ReceiveMessages(contact, sock, c, e) return } // Sends a SIP message. func (tp *Transport) Send(msg *Msg) error { - msg, hostport, err := RouteMessage(tp.Via, tp.Contact, msg) + msg, host, port, err := RouteMessage(tp.Via, tp.Contact, msg) if err != nil { return err } - addr, err := net.ResolveUDPAddr("udp", hostport) + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(host, portstr(port))) if err != nil { return err } @@ -95,23 +90,3 @@ func (tp *Transport) Send(msg *Msg) error { } return nil } - -// Checks if message is acceptable, otherwise sets msg.Error and returns false. -func (tp *Transport) sanityCheck(msg *Msg) error { - if msg.MaxForwards <= 0 { - tp.Send(NewResponse(msg, StatusTooManyHops)) - return errors.New("Froot loop detected") - } - if msg.IsResponse { - if msg.Status >= 700 { - tp.Send(NewResponse(msg, StatusBadRequest)) - return errors.New("Crazy status number") - } - } else { - if msg.CSeqMethod == "" || msg.CSeqMethod != msg.Method { - tp.Send(NewResponse(msg, StatusBadRequest)) - return errors.New("Bad CSeq") - } - } - return nil -}