Implement UID resolution via INET_DIAG

This speeds up UID resolution with root capture
This commit is contained in:
emanuele-f 2022-04-07 00:36:24 +02:00
parent 72a1083b3f
commit a626e11c62
4 changed files with 164 additions and 19 deletions

View File

@ -23,11 +23,14 @@
#include <sys/socket.h> #include <sys/socket.h>
#include <linux/netlink.h> #include <linux/netlink.h>
#include <linux/rtnetlink.h> #include <linux/rtnetlink.h>
#include <linux/sock_diag.h>
#include <linux/inet_diag.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <net/if.h> #include <net/if.h>
#include "nl_utils.h" #include "nl_utils.h"
#include "common/uid_resolver.h"
int nl_socket(uint32_t groups) { int nl_route_socket(uint32_t groups) {
struct sockaddr_nl snl; struct sockaddr_nl snl;
int sock; int sock;
@ -146,3 +149,120 @@ out:
return(rv); return(rv);
} }
/* ******************************************************* */
static int diag_uid_lookup(int nlsock, int family, int ipproto,
const pd_sockaddr_t *local, const pd_sockaddr_t *remote,
int flags) {
struct sockaddr_nl snl = {0};
struct msghdr msg = {0};
struct iovec iov;
u_char buf[512];
static int seq = 0;
ssize_t rv;
struct nlmsghdr *nmsg = (struct nlmsghdr*) buf;
struct inet_diag_req_v2 *req = (struct inet_diag_req_v2*) (nmsg + 1);
memset(req, 0, sizeof(*req));
req->sdiag_family = family;
req->sdiag_protocol = ipproto;
req->idiag_states = -1 /* ANY state */;
req->id.idiag_sport = local->port;
req->id.idiag_dport = remote->port;
req->id.idiag_cookie[0] = -1, req->id.idiag_cookie[1] = -1; /* no cookie */
if(family == AF_INET) {
memcpy(req->id.idiag_src, &local->addr.ip4, 4);
memcpy(req->id.idiag_dst, &remote->addr.ip4, 4);
} else {
memcpy(req->id.idiag_src, &local->addr.ip6, 16);
memcpy(req->id.idiag_dst, &remote->addr.ip6, 16);
}
memset(nmsg, 0, sizeof(*nmsg));
nmsg->nlmsg_len = sizeof(*nmsg) + sizeof(*req);
nmsg->nlmsg_type = SOCK_DIAG_BY_FAMILY;
nmsg->nlmsg_flags = flags;
nmsg->nlmsg_seq = ++seq;
iov.iov_base = (void*) nmsg;
iov.iov_len = nmsg->nlmsg_len;
snl.nl_family = AF_NETLINK;
msg.msg_name = (void*) &snl;
msg.msg_namelen = sizeof(snl);
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
// Send request
if(sendmsg(nlsock, &msg, 0) < 0)
return -1;
iov.iov_base = buf;
iov.iov_len = sizeof(buf);
// Recv reply
if((rv = recvmsg(nlsock, &msg, 0)) <= 0)
return -2;
// NOTE: nmsg points to buf
if(nmsg->nlmsg_len < (int)sizeof(*nmsg) || nmsg->nlmsg_len > rv ||
nmsg->nlmsg_seq != seq) {
errno = EINVAL;
return -3;
}
if(nmsg->nlmsg_type == NLMSG_ERROR)
return -4;
struct inet_diag_msg *diag_msg = (struct inet_diag_msg*) NLMSG_DATA(nmsg);
return diag_msg->idiag_uid;
}
/* ******************************************************* */
int nl_get_uid(int nlsock, const zdtun_5tuple_t *tuple) {
int uid;
int family = (tuple->ipver == 4) ? AF_INET : AF_INET6;
int ipproto = tuple->ipproto;
pd_sockaddr_t src = {.addr = tuple->src_ip, .port = tuple->src_port};
pd_sockaddr_t dst = {.addr = tuple->dst_ip, .port = tuple->dst_port};
// fix to known bug with UDP: https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html
const pd_sockaddr_t *local = (ipproto == IPPROTO_UDP) ? &dst : &src;
const pd_sockaddr_t *remote = (ipproto == IPPROTO_UDP) ? &src : &dst;
uid = diag_uid_lookup(nlsock, family, ipproto, local, remote, NLM_F_REQUEST);
if(uid >= 0)
return uid;
// Search for IPv4-mapped IPv6 addresses
if(family == AF_INET) {
uid = diag_uid_lookup(nlsock, AF_INET6, ipproto, local, remote, NLM_F_REQUEST);
if(uid >= 0)
return uid;
}
// For UDP it's possible for a socket to send packets to arbitrary destinations
// See InetDiagMessage.java in Android
if(ipproto == IPPROTO_UDP) {
pd_sockaddr_t wildcard = {0};
uid = diag_uid_lookup(nlsock, family, ipproto, &src, &wildcard, NLM_F_REQUEST | NLM_F_DUMP);
if(uid >= 0)
return uid;
// Search for IPv4-mapped IPv6 addresses
if(family == AF_INET) {
uid = diag_uid_lookup(nlsock, AF_INET6, ipproto, &src, &wildcard, NLM_F_REQUEST | NLM_F_DUMP);
if(uid >= 0)
return uid;
}
}
return UID_UNKNOWN;
}

View File

@ -21,6 +21,10 @@
#define __NL_UTILS_H__ #define __NL_UTILS_H__
#include <stdint.h> #include <stdint.h>
#include <zdtun.h>
/* >= 8192 to avoid truncation, see "man 7 netlink" */
#define NL_BUFFER_SIZE 8192
typedef struct { typedef struct {
union { union {
@ -29,6 +33,11 @@ typedef struct {
}; };
} __attribute__((packed)) addr_t; } __attribute__((packed)) addr_t;
typedef struct {
zdtun_ip_t addr;
uint16_t port;
} pd_sockaddr_t;
typedef struct { typedef struct {
addr_t gateway; addr_t gateway;
int ifidx; int ifidx;
@ -36,6 +45,7 @@ typedef struct {
} route_info_t; } route_info_t;
int nl_get_route(int af, const addr_t *addr, route_info_t *out); int nl_get_route(int af, const addr_t *addr, route_info_t *out);
int nl_socket(uint32_t groups); int nl_route_socket(uint32_t groups);
int nl_get_uid(int nlsock, const zdtun_5tuple_t *tuple);
#endif #endif

View File

@ -227,8 +227,10 @@ static int create_pid_file() {
static void finish_pcapd_capture(pcapd_runtime_t *rt) { static void finish_pcapd_capture(pcapd_runtime_t *rt) {
if(rt->client > 0) if(rt->client > 0)
close(rt->client); close(rt->client);
if(rt->nlsock > 0) if(rt->nlroute_sock > 0)
close(rt->nlsock); close(rt->nlroute_sock);
if(rt->nldiag_sock > 0)
close(rt->nldiag_sock);
if(rt->lru) if(rt->lru)
uid_lru_destroy(rt->lru); uid_lru_destroy(rt->lru);
if(rt->resolver) if(rt->resolver)
@ -259,7 +261,8 @@ static int init_pcapd_capture(pcapd_runtime_t *rt, pcapd_conf_t *conf) {
signal(SIGPIPE, SIG_IGN); signal(SIGPIPE, SIG_IGN);
} }
rt->nlsock = -1; rt->nlroute_sock = -1;
rt->nldiag_sock = -1;
rt->client = -1; rt->client = -1;
rt->conf = conf; rt->conf = conf;
@ -281,15 +284,19 @@ static int init_pcapd_capture(pcapd_runtime_t *rt, pcapd_conf_t *conf) {
} }
if(rt->inet_iface) { if(rt->inet_iface) {
rt->nlsock = nl_socket(RTMGRP_IPV4_ROUTE | RTMGRP_IPV4_IFADDR | RTMGRP_IPV4_RULE | rt->nlroute_sock = nl_route_socket(RTMGRP_IPV4_ROUTE | RTMGRP_IPV4_IFADDR | RTMGRP_IPV4_RULE |
RTMGRP_IPV6_ROUTE | RTMGRP_IPV6_IFADDR | RTMGRP_LINK); RTMGRP_IPV6_ROUTE | RTMGRP_IPV6_IFADDR | RTMGRP_LINK);
if(rt->nlsock < 0) { if(rt->nlroute_sock < 0) {
log_e("could not create netlink socket[%d]: %s", errno, strerror(errno)); log_e("could not create netlink socket[%d]: %s", errno, strerror(errno));
goto err; goto err;
} }
rt->maxfd = max(rt->maxfd, rt->nlsock); rt->maxfd = max(rt->maxfd, rt->nlroute_sock);
} }
rt->nldiag_sock = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_INET_DIAG);
if(rt->nldiag_sock < 0)
log_w("could not open NETLINK_INET_DIAG[%d]: %s", errno, strerror(errno));
signal(SIGINT, &sighandler); signal(SIGINT, &sighandler);
signal(SIGTERM, &sighandler); signal(SIGTERM, &sighandler);
signal(SIGHUP, &sighandler); signal(SIGHUP, &sighandler);
@ -533,7 +540,7 @@ static int handle_nl_message(pcapd_runtime_t *rt) {
.msg_iovlen = 1 .msg_iovlen = 1
}; };
ssize_t len = recvmsg(rt->nlsock, &msg, 0); ssize_t len = recvmsg(rt->nlroute_sock, &msg, 0);
uint8_t recheck_inet = 0; uint8_t recheck_inet = 0;
#ifdef READ_FROM_PCAP #ifdef READ_FROM_PCAP
@ -665,8 +672,8 @@ static void get_selectable_fds(pcapd_runtime_t *rt, fd_set *fds) {
if(rt->client > 0) if(rt->client > 0)
FD_SET(rt->client, fds); FD_SET(rt->client, fds);
if(rt->nlsock > 0) if(rt->nlroute_sock > 0)
FD_SET(rt->nlsock, fds); FD_SET(rt->nlroute_sock, fds);
for(int i=0; i<rt->conf->num_interfaces; i++) { for(int i=0; i<rt->conf->num_interfaces; i++) {
if(rt->ifaces[i].pf != -1) if(rt->ifaces[i].pf != -1)
@ -729,7 +736,13 @@ static int read_pkt(pcapd_runtime_t *rt, pcapd_iface_t *iface, time_t now) {
uid = uid_lru_find(rt->lru, &zpkt.tuple); uid = uid_lru_find(rt->lru, &zpkt.tuple);
if(uid == -2) { if(uid == -2) {
uid = get_uid(rt->resolver, &zpkt.tuple); if((rt->nldiag_sock > 0) && (zpkt.tuple.ipproto != IPPROTO_ICMP))
// retrieve via netlink
uid = nl_get_uid(rt->nldiag_sock, &zpkt.tuple);
else
// slow method
uid = get_uid(rt->resolver, &zpkt.tuple);
uid_lru_add(rt->lru, &zpkt.tuple, uid); uid_lru_add(rt->lru, &zpkt.tuple, uid);
} }
} }
@ -758,12 +771,13 @@ static int read_pkt(pcapd_runtime_t *rt, pcapd_iface_t *iface, time_t now) {
log_e("write failed[%d]: %s", errno, strerror(errno)); log_e("write failed[%d]: %s", errno, strerror(errno));
return -1; return -1;
} }
} else { } else if(!rt->conf->quiet) {
char buf[512]; char buf[512];
zdtun_5tuple2str(&zpkt.tuple, buf, sizeof(buf)); zdtun_5tuple2str(&zpkt.tuple, buf, sizeof(buf));
if(!rt->conf->quiet) printf("[%s:%d] %s (%u B) [%cX] (%d)\n", iface->name,
printf("[%s:%d] %s (%u B) [%cX]\n", iface->name, iface->ifid, buf, phdr.len, is_tx ? 'T' : 'R'); iface->ifid, buf, phdr.len, is_tx ? 'T' : 'R',
uid);
} }
if(iface->is_file) { if(iface->is_file) {
@ -836,11 +850,10 @@ int run_pcap_dump(pcapd_conf_t *conf) {
clock_gettime(CLOCK_MONOTONIC_COARSE, &ts); clock_gettime(CLOCK_MONOTONIC_COARSE, &ts);
time_t now = ts.tv_sec; time_t now = ts.tv_sec;
if((rt.client > 0) && FD_ISSET(rt.client, &fds)) { if((rt.client > 0) && FD_ISSET(rt.client, &fds)) {
log_i("Client closed"); log_i("Client closed");
break; break;
} else if((rt.nlsock > 0) && FD_ISSET(rt.nlsock, &fds)) { } else if((rt.nlroute_sock > 0) && FD_ISSET(rt.nlroute_sock, &fds)) {
if(handle_nl_message(&rt) < 0) { if(handle_nl_message(&rt) < 0) {
rv = -1; rv = -1;
break; break;

View File

@ -23,6 +23,7 @@
#include <stdint.h> #include <stdint.h>
#include <net/if.h> #include <net/if.h>
#include <pcap.h> #include <pcap.h>
#include "nl_utils.h"
#include "common/uid_lru.h" #include "common/uid_lru.h"
#include "common/uid_resolver.h" #include "common/uid_resolver.h"
#include "common/utils.h" #include "common/utils.h"
@ -61,9 +62,10 @@ typedef struct {
typedef struct { typedef struct {
char bpf[512]; char bpf[512];
char nlbuf[8192]; /* >= 8192 to avoid truncation, see "man 7 netlink" */ char nlbuf[NL_BUFFER_SIZE];
int nlsock; int nlroute_sock;
int nldiag_sock;
int client; int client;
zdtun_t *tun; zdtun_t *tun;