diff options
Diffstat (limited to 'net/netfilter/xt_socket.c')
-rw-r--r-- | net/netfilter/xt_socket.c | 31 |
1 files changed, 23 insertions, 8 deletions
diff --git a/net/netfilter/xt_socket.c b/net/netfilter/xt_socket.c index e092cb046326..281f4d65b6ea 100644 --- a/net/netfilter/xt_socket.c +++ b/net/netfilter/xt_socket.c @@ -143,10 +143,11 @@ static bool xt_socket_sk_is_transparent(struct sock *sk) } } -static struct sock *xt_socket_lookup_slow_v4(const struct sk_buff *skb, +struct sock *xt_socket_lookup_slow_v4(const struct sk_buff *skb, const struct net_device *indev) { const struct iphdr *iph = ip_hdr(skb); + struct sock *sk = skb->sk; __be32 uninitialized_var(daddr), uninitialized_var(saddr); __be16 uninitialized_var(dport), uninitialized_var(sport); u8 uninitialized_var(protocol); @@ -197,9 +198,16 @@ static struct sock *xt_socket_lookup_slow_v4(const struct sk_buff *skb, } #endif - return xt_socket_get_sock_v4(dev_net(skb->dev), protocol, saddr, daddr, - sport, dport, indev); + if (sk) + atomic_inc(&sk->sk_refcnt); + else + sk = xt_socket_get_sock_v4(dev_net(skb->dev), protocol, + saddr, daddr, sport, dport, + indev); + + return sk; } +EXPORT_SYMBOL(xt_socket_lookup_slow_v4); static bool socket_match(const struct sk_buff *skb, struct xt_action_param *par, @@ -226,8 +234,7 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par, if (info->flags & XT_SOCKET_TRANSPARENT) transparent = xt_socket_sk_is_transparent(sk); - if (sk != skb->sk) - sock_gen_put(sk); + sock_gen_put(sk); if (wildcard || !transparent) sk = NULL; @@ -330,9 +337,10 @@ xt_socket_get_sock_v6(struct net *net, const u8 protocol, return NULL; } -static struct sock *xt_socket_lookup_slow_v6(const struct sk_buff *skb, +struct sock *xt_socket_lookup_slow_v6(const struct sk_buff *skb, const struct net_device *indev) { + struct sock *sk = skb->sk; __be16 uninitialized_var(dport), uninitialized_var(sport); const struct in6_addr *daddr = NULL, *saddr = NULL; struct ipv6hdr *iph = ipv6_hdr(skb); @@ -366,9 +374,16 @@ static struct sock *xt_socket_lookup_slow_v6(const struct sk_buff *skb, return NULL; } - return xt_socket_get_sock_v6(dev_net(skb->dev), tproto, saddr, daddr, - sport, dport, indev); + if (sk) + atomic_inc(&sk->sk_refcnt); + else + sk = xt_socket_get_sock_v6(dev_net(skb->dev), tproto, + saddr, daddr, sport, dport, + indev); + + return sk; } +EXPORT_SYMBOL(xt_socket_lookup_slow_v6); static bool socket_mt6_v1_v2(const struct sk_buff *skb, struct xt_action_param *par) |