diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 3bb213e2a967e8..1eaaf91eb978b5 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -1393,11 +1393,15 @@ static bool mptcp_pm_has_addr_attr_id(const struct nlattr *attr, int mptcp_pm_nl_add_addr_doit(struct sk_buff *skb, struct genl_info *info) { - struct nlattr *attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; struct pm_nl_pernet *pernet = genl_info_pm_nl(info); struct mptcp_pm_addr_entry addr, *entry; + struct nlattr *attr; int ret; + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ENDPOINT_ADDR)) + return -EINVAL; + + attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; ret = mptcp_pm_parse_entry(attr, info, true, &addr); if (ret < 0) return ret; @@ -1590,12 +1594,16 @@ static int mptcp_nl_remove_id_zero_address(struct net *net, int mptcp_pm_nl_del_addr_doit(struct sk_buff *skb, struct genl_info *info) { - struct nlattr *attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; struct pm_nl_pernet *pernet = genl_info_pm_nl(info); struct mptcp_pm_addr_entry addr, *entry; unsigned int addr_max; + struct nlattr *attr; int ret; + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ENDPOINT_ADDR)) + return -EINVAL; + + attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; ret = mptcp_pm_parse_entry(attr, info, false, &addr); if (ret < 0) return ret; @@ -1767,13 +1775,17 @@ int mptcp_nl_fill_addr(struct sk_buff *skb, int mptcp_pm_nl_get_addr(struct sk_buff *skb, struct genl_info *info) { - struct nlattr *attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; struct pm_nl_pernet *pernet = genl_info_pm_nl(info); struct mptcp_pm_addr_entry addr, *entry; struct sk_buff *msg; + struct nlattr *attr; void *reply; int ret; + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ENDPOINT_ADDR)) + return -EINVAL; + + attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; ret = mptcp_pm_parse_entry(attr, info, false, &addr); if (ret < 0) return ret; @@ -1989,18 +2001,20 @@ static int mptcp_nl_set_flags(struct net *net, int mptcp_pm_nl_set_flags(struct sk_buff *skb, struct genl_info *info) { struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }; - struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; u8 changed, mask = MPTCP_PM_ADDR_FLAG_BACKUP | MPTCP_PM_ADDR_FLAG_FULLMESH; struct net *net = sock_net(skb->sk); struct mptcp_pm_addr_entry *entry; struct pm_nl_pernet *pernet; + struct nlattr *attr; u8 lookup_by_id = 0; u8 bkup = 0; int ret; - pernet = pm_nl_get_pernet(net); + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_ADDR)) + return -EINVAL; + attr = info->attrs[MPTCP_PM_ATTR_ADDR]; ret = mptcp_pm_parse_entry(attr, info, false, &addr); if (ret < 0) return ret; @@ -2017,6 +2031,8 @@ int mptcp_pm_nl_set_flags(struct sk_buff *skb, struct genl_info *info) if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP) bkup = 1; + pernet = pm_nl_get_pernet(net); + spin_lock_bh(&pernet->lock); entry = lookup_by_id ? __lookup_addr_by_id(pernet, addr.addr.id) : __lookup_addr(pernet, &addr.addr); diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c index 51943ca6ddb607..f5515b2305f41f 100644 --- a/net/mptcp/pm_userspace.c +++ b/net/mptcp/pm_userspace.c @@ -570,20 +570,24 @@ int mptcp_userspace_pm_set_flags(struct sk_buff *skb, struct genl_info *info) { struct mptcp_pm_addr_entry loc = { .addr = { .family = AF_UNSPEC }, }; struct mptcp_pm_addr_entry rem = { .addr = { .family = AF_UNSPEC }, }; - struct nlattr *attr_rem = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; - struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; struct mptcp_pm_addr_entry *entry; + struct nlattr *attr, *attr_rem; struct mptcp_sock *msk; int ret = -EINVAL; struct sock *sk; u8 bkup = 0; + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_ADDR) || + GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_ADDR_REMOTE)) + return ret; + msk = mptcp_userspace_pm_get_sock(info); if (!msk) return ret; sk = (struct sock *)msk; + attr = info->attrs[MPTCP_PM_ATTR_ADDR]; ret = mptcp_pm_parse_entry(attr, info, false, &loc); if (ret < 0) goto set_flags_err; @@ -595,6 +599,7 @@ int mptcp_userspace_pm_set_flags(struct sk_buff *skb, struct genl_info *info) goto set_flags_err; } + attr_rem = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; ret = mptcp_pm_parse_entry(attr_rem, info, false, &rem); if (ret < 0) goto set_flags_err; @@ -683,20 +688,24 @@ int mptcp_userspace_pm_dump_addr(struct sk_buff *msg, int mptcp_userspace_pm_get_addr(struct sk_buff *skb, struct genl_info *info) { - struct nlattr *attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; struct mptcp_pm_addr_entry addr, *entry; struct mptcp_sock *msk; struct sk_buff *msg; + struct nlattr *attr; int ret = -EINVAL; struct sock *sk; void *reply; + if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ENDPOINT_ADDR)) + return ret; + msk = mptcp_userspace_pm_get_sock(info); if (!msk) return ret; sk = (struct sock *)msk; + attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; ret = mptcp_pm_parse_entry(attr, info, false, &addr); if (ret < 0) goto out;