diff --git a/lib/socket.c b/lib/socket.c index 380fecbd0..c7f4de994 100644 --- a/lib/socket.c +++ b/lib/socket.c @@ -27,6 +27,8 @@ static int __ip_connect(socket_t *s, const endpoint_t *); static int __ip_listen(socket_t *s, int backlog); static int __ip_accept(socket_t *s, socket_t *new_sock); static int __ip_timestamping(socket_t *s); +static int __ip4_pktinfo(socket_t *s); +static int __ip6_pktinfo(socket_t *s); static int __ip4_sockaddr2endpoint(endpoint_t *, const void *); static int __ip6_sockaddr2endpoint(endpoint_t *, const void *); static int __ip4_endpoint2sockaddr(void *, const endpoint_t *); @@ -81,6 +83,7 @@ static struct socket_family __socket_families[__SF_LAST] = { .listen = __ip_listen, .accept = __ip_accept, .timestamping = __ip_timestamping, + .pktinfo = __ip4_pktinfo, .recvfrom = __ip_recvfrom, .recvfrom_ts = __ip_recvfrom_ts, .sendmsg = __ip_sendmsg, @@ -113,6 +116,7 @@ static struct socket_family __socket_families[__SF_LAST] = { .listen = __ip_listen, .accept = __ip_accept, .timestamping = __ip_timestamping, + .pktinfo = __ip6_pktinfo, .recvfrom = __ip_recvfrom, .recvfrom_ts = __ip_recvfrom_ts, .sendmsg = __ip_sendmsg, @@ -396,6 +400,18 @@ static int __ip_timestamping(socket_t *s) { return -1; return 0; } +static int __ip4_pktinfo(socket_t *s) { + int one = 1; + if (setsockopt(s->fd, IPPROTO_IP, IP_PKTINFO, &one, sizeof(one))) + return -1; + return 0; +} +static int __ip6_pktinfo(socket_t *s) { + int one = 1; + if (setsockopt(s->fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one))) + return -1; + return 0; +} static void __ip4_endpoint2kernel(struct re_address *ra, const endpoint_t *ep) { ZERO(*ra); ra->family = AF_INET; diff --git a/lib/socket.h b/lib/socket.h index f9b4e07c2..3ca4a4aef 100644 --- a/lib/socket.h +++ b/lib/socket.h @@ -72,6 +72,7 @@ struct socket_family { int (*listen)(socket_t *, int); int (*accept)(socket_t *, socket_t *); int (*timestamping)(socket_t *); + int (*pktinfo)(socket_t *); ssize_t (*recvfrom)(socket_t *, void *, size_t, endpoint_t *); ssize_t (*recvfrom_ts)(socket_t *, void *, size_t, endpoint_t *, struct timeval *); ssize_t (*sendmsg)(socket_t *, struct msghdr *, const endpoint_t *); @@ -172,6 +173,7 @@ INLINE int is_addr_unspecified(const sockaddr_t *a) { #define socket_sendto(s,a...) (s)->family->sendto((s), a) #define socket_error(s) (s)->family->error((s)) #define socket_timestamping(s) (s)->family->timestamping((s)) +#define socket_pktinfo(s) (s)->family->pktinfo((s)) INLINE ssize_t socket_sendiov(socket_t *s, const struct iovec *v, unsigned int len, const endpoint_t *dst) { struct msghdr mh; ZERO(mh);