1 // SPDX-License-Identifier: GPL-2.0 or BSD-3-Clause
2 /*
3 * Copyright(c) 2016 Intel Corporation.
4 */
5
6 #include <linux/slab.h>
7 #include <linux/sched.h>
8 #include <linux/rculist.h>
9 #include <rdma/rdma_vt.h>
10 #include <rdma/rdmavt_qp.h>
11
12 #include "mcast.h"
13
14 /**
15 * rvt_driver_mcast_init - init resources for multicast
16 * @rdi: rvt dev struct
17 *
18 * This is per device that registers with rdmavt
19 */
rvt_driver_mcast_init(struct rvt_dev_info * rdi)20 void rvt_driver_mcast_init(struct rvt_dev_info *rdi)
21 {
22 /*
23 * Anything that needs setup for multicast on a per driver or per rdi
24 * basis should be done in here.
25 */
26 spin_lock_init(&rdi->n_mcast_grps_lock);
27 }
28
29 /**
30 * rvt_mcast_qp_alloc - alloc a struct to link a QP to mcast GID struct
31 * @qp: the QP to link
32 */
rvt_mcast_qp_alloc(struct rvt_qp * qp)33 static struct rvt_mcast_qp *rvt_mcast_qp_alloc(struct rvt_qp *qp)
34 {
35 struct rvt_mcast_qp *mqp;
36
37 mqp = kmalloc(sizeof(*mqp), GFP_KERNEL);
38 if (!mqp)
39 goto bail;
40
41 mqp->qp = qp;
42 rvt_get_qp(qp);
43
44 bail:
45 return mqp;
46 }
47
rvt_mcast_qp_free(struct rvt_mcast_qp * mqp)48 static void rvt_mcast_qp_free(struct rvt_mcast_qp *mqp)
49 {
50 struct rvt_qp *qp = mqp->qp;
51
52 /* Notify hfi1_destroy_qp() if it is waiting. */
53 rvt_put_qp(qp);
54
55 kfree(mqp);
56 }
57
58 /**
59 * rvt_mcast_alloc - allocate the multicast GID structure
60 * @mgid: the multicast GID
61 * @lid: the muilticast LID (host order)
62 *
63 * A list of QPs will be attached to this structure.
64 */
rvt_mcast_alloc(union ib_gid * mgid,u16 lid)65 static struct rvt_mcast *rvt_mcast_alloc(union ib_gid *mgid, u16 lid)
66 {
67 struct rvt_mcast *mcast;
68
69 mcast = kzalloc(sizeof(*mcast), GFP_KERNEL);
70 if (!mcast)
71 goto bail;
72
73 mcast->mcast_addr.mgid = *mgid;
74 mcast->mcast_addr.lid = lid;
75
76 INIT_LIST_HEAD(&mcast->qp_list);
77 init_waitqueue_head(&mcast->wait);
78 atomic_set(&mcast->refcount, 0);
79
80 bail:
81 return mcast;
82 }
83
rvt_mcast_free(struct rvt_mcast * mcast)84 static void rvt_mcast_free(struct rvt_mcast *mcast)
85 {
86 struct rvt_mcast_qp *p, *tmp;
87
88 list_for_each_entry_safe(p, tmp, &mcast->qp_list, list)
89 rvt_mcast_qp_free(p);
90
91 kfree(mcast);
92 }
93
94 /**
95 * rvt_mcast_find - search the global table for the given multicast GID/LID
96 * NOTE: It is valid to have 1 MLID with multiple MGIDs. It is not valid
97 * to have 1 MGID with multiple MLIDs.
98 * @ibp: the IB port structure
99 * @mgid: the multicast GID to search for
100 * @lid: the multicast LID portion of the multicast address (host order)
101 *
102 * The caller is responsible for decrementing the reference count if found.
103 *
104 * Return: NULL if not found.
105 */
rvt_mcast_find(struct rvt_ibport * ibp,union ib_gid * mgid,u16 lid)106 struct rvt_mcast *rvt_mcast_find(struct rvt_ibport *ibp, union ib_gid *mgid,
107 u16 lid)
108 {
109 struct rb_node *n;
110 unsigned long flags;
111 struct rvt_mcast *found = NULL;
112
113 spin_lock_irqsave(&ibp->lock, flags);
114 n = ibp->mcast_tree.rb_node;
115 while (n) {
116 int ret;
117 struct rvt_mcast *mcast;
118
119 mcast = rb_entry(n, struct rvt_mcast, rb_node);
120
121 ret = memcmp(mgid->raw, mcast->mcast_addr.mgid.raw,
122 sizeof(*mgid));
123 if (ret < 0) {
124 n = n->rb_left;
125 } else if (ret > 0) {
126 n = n->rb_right;
127 } else {
128 /* MGID/MLID must match */
129 if (mcast->mcast_addr.lid == lid) {
130 atomic_inc(&mcast->refcount);
131 found = mcast;
132 }
133 break;
134 }
135 }
136 spin_unlock_irqrestore(&ibp->lock, flags);
137 return found;
138 }
139 EXPORT_SYMBOL(rvt_mcast_find);
140
141 /*
142 * rvt_mcast_add - insert mcast GID into table and attach QP struct
143 * @mcast: the mcast GID table
144 * @mqp: the QP to attach
145 *
146 * Return: zero if both were added. Return EEXIST if the GID was already in
147 * the table but the QP was added. Return ESRCH if the QP was already
148 * attached and neither structure was added. Return EINVAL if the MGID was
149 * found, but the MLID did NOT match.
150 */
rvt_mcast_add(struct rvt_dev_info * rdi,struct rvt_ibport * ibp,struct rvt_mcast * mcast,struct rvt_mcast_qp * mqp)151 static int rvt_mcast_add(struct rvt_dev_info *rdi, struct rvt_ibport *ibp,
152 struct rvt_mcast *mcast, struct rvt_mcast_qp *mqp)
153 {
154 struct rb_node **n = &ibp->mcast_tree.rb_node;
155 struct rb_node *pn = NULL;
156 int ret;
157
158 spin_lock_irq(&ibp->lock);
159
160 while (*n) {
161 struct rvt_mcast *tmcast;
162 struct rvt_mcast_qp *p;
163
164 pn = *n;
165 tmcast = rb_entry(pn, struct rvt_mcast, rb_node);
166
167 ret = memcmp(mcast->mcast_addr.mgid.raw,
168 tmcast->mcast_addr.mgid.raw,
169 sizeof(mcast->mcast_addr.mgid));
170 if (ret < 0) {
171 n = &pn->rb_left;
172 continue;
173 }
174 if (ret > 0) {
175 n = &pn->rb_right;
176 continue;
177 }
178
179 if (tmcast->mcast_addr.lid != mcast->mcast_addr.lid) {
180 ret = EINVAL;
181 goto bail;
182 }
183
184 /* Search the QP list to see if this is already there. */
185 list_for_each_entry_rcu(p, &tmcast->qp_list, list) {
186 if (p->qp == mqp->qp) {
187 ret = ESRCH;
188 goto bail;
189 }
190 }
191 if (tmcast->n_attached ==
192 rdi->dparms.props.max_mcast_qp_attach) {
193 ret = ENOMEM;
194 goto bail;
195 }
196
197 tmcast->n_attached++;
198
199 list_add_tail_rcu(&mqp->list, &tmcast->qp_list);
200 ret = EEXIST;
201 goto bail;
202 }
203
204 spin_lock(&rdi->n_mcast_grps_lock);
205 if (rdi->n_mcast_grps_allocated == rdi->dparms.props.max_mcast_grp) {
206 spin_unlock(&rdi->n_mcast_grps_lock);
207 ret = ENOMEM;
208 goto bail;
209 }
210
211 rdi->n_mcast_grps_allocated++;
212 spin_unlock(&rdi->n_mcast_grps_lock);
213
214 mcast->n_attached++;
215
216 list_add_tail_rcu(&mqp->list, &mcast->qp_list);
217
218 atomic_inc(&mcast->refcount);
219 rb_link_node(&mcast->rb_node, pn, n);
220 rb_insert_color(&mcast->rb_node, &ibp->mcast_tree);
221
222 ret = 0;
223
224 bail:
225 spin_unlock_irq(&ibp->lock);
226
227 return ret;
228 }
229
230 /**
231 * rvt_attach_mcast - attach a qp to a multicast group
232 * @ibqp: Infiniband qp
233 * @gid: multicast guid
234 * @lid: multicast lid
235 *
236 * Return: 0 on success
237 */
rvt_attach_mcast(struct ib_qp * ibqp,union ib_gid * gid,u16 lid)238 int rvt_attach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
239 {
240 struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
241 struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
242 struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
243 struct rvt_mcast *mcast;
244 struct rvt_mcast_qp *mqp;
245 int ret = -ENOMEM;
246
247 if (ibqp->qp_num <= 1 || qp->state == IB_QPS_RESET)
248 return -EINVAL;
249
250 /*
251 * Allocate data structures since its better to do this outside of
252 * spin locks and it will most likely be needed.
253 */
254 mcast = rvt_mcast_alloc(gid, lid);
255 if (!mcast)
256 return -ENOMEM;
257
258 mqp = rvt_mcast_qp_alloc(qp);
259 if (!mqp)
260 goto bail_mcast;
261
262 switch (rvt_mcast_add(rdi, ibp, mcast, mqp)) {
263 case ESRCH:
264 /* Neither was used: OK to attach the same QP twice. */
265 ret = 0;
266 goto bail_mqp;
267 case EEXIST: /* The mcast wasn't used */
268 ret = 0;
269 goto bail_mcast;
270 case ENOMEM:
271 /* Exceeded the maximum number of mcast groups. */
272 ret = -ENOMEM;
273 goto bail_mqp;
274 case EINVAL:
275 /* Invalid MGID/MLID pair */
276 ret = -EINVAL;
277 goto bail_mqp;
278 default:
279 break;
280 }
281
282 return 0;
283
284 bail_mqp:
285 rvt_mcast_qp_free(mqp);
286
287 bail_mcast:
288 rvt_mcast_free(mcast);
289
290 return ret;
291 }
292
/**
 * rvt_detach_mcast - remove a qp from a multicast group
 * @ibqp: Infiniband qp
 * @gid: multicast guid
 * @lid: multicast lid
 *
 * Looks up the group and unlinks the QP from it under ibp->lock. If that
 * was the last attached QP, the group is also erased from the tree. After
 * dropping the lock, it waits until only the tree's own reference on the
 * group remains (walkers obtained via rvt_mcast_find() hold the others)
 * before freeing the QP link. For the last QP it then drops the tree's
 * reference, waits for zero, frees the group and decrements the
 * device-wide group count.
 *
 * NOTE(review): a non-last detacher sleeps on mcast->wait after releasing
 * ibp->lock without holding a reference of its own on the group. If a
 * detach of another QP concurrently removes the last member, that thread
 * may free the group while this one is still waiting on it. Confirm
 * whether callers serialize detaches on the same group.
 *
 * Return: 0 on success. -EINVAL if ibqp->qp_num <= 1, the MGID is not in
 * the tree, the MLID does not match, or the QP is not attached.
 */
int rvt_detach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
{
	struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
	struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
	struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
	struct rvt_mcast *mcast = NULL;
	struct rvt_mcast_qp *p, *tmp, *delp = NULL;
	struct rb_node *n;
	int last = 0;
	/* Only used for the memcmp() result during the tree search. */
	int ret = 0;

	if (ibqp->qp_num <= 1)
		return -EINVAL;

	spin_lock_irq(&ibp->lock);

	/* Find the GID in the mcast table. */
	n = ibp->mcast_tree.rb_node;
	while (1) {
		if (!n) {
			spin_unlock_irq(&ibp->lock);
			return -EINVAL;
		}

		mcast = rb_entry(n, struct rvt_mcast, rb_node);
		ret = memcmp(gid->raw, mcast->mcast_addr.mgid.raw,
			     sizeof(*gid));
		if (ret < 0) {
			n = n->rb_left;
		} else if (ret > 0) {
			n = n->rb_right;
		} else {
			/* MGID/MLID must match */
			if (mcast->mcast_addr.lid != lid) {
				spin_unlock_irq(&ibp->lock);
				return -EINVAL;
			}
			break;
		}
	}

	/* Search the QP list. */
	list_for_each_entry_safe(p, tmp, &mcast->qp_list, list) {
		if (p->qp != qp)
			continue;
		/*
		 * We found it, so remove it, but don't poison the forward
		 * link until we are sure there are no list walkers.
		 */
		list_del_rcu(&p->list);
		mcast->n_attached--;
		delp = p;

		/* If this was the last attached QP, remove the GID too. */
		if (list_empty(&mcast->qp_list)) {
			rb_erase(&mcast->rb_node, &ibp->mcast_tree);
			last = 1;
		}
		break;
	}

	spin_unlock_irq(&ibp->lock);
	/* QP not attached */
	if (!delp)
		return -EINVAL;

	/*
	 * Wait for any list walkers to finish before freeing the
	 * list element. The remaining count of 1 is the reference that
	 * rvt_mcast_add() took when inserting the group into the tree.
	 */
	wait_event(mcast->wait, atomic_read(&mcast->refcount) <= 1);
	rvt_mcast_qp_free(delp);

	if (last) {
		/* Drop the tree's reference and wait for all others. */
		atomic_dec(&mcast->refcount);
		wait_event(mcast->wait, !atomic_read(&mcast->refcount));
		rvt_mcast_free(mcast);
		spin_lock_irq(&rdi->n_mcast_grps_lock);
		rdi->n_mcast_grps_allocated--;
		spin_unlock_irq(&rdi->n_mcast_grps_lock);
	}

	return 0;
}
385
386 /**
387 * rvt_mcast_tree_empty - determine if any qps are attached to any mcast group
388 * @rdi: rvt dev struct
389 *
390 * Return: in use count
391 */
rvt_mcast_tree_empty(struct rvt_dev_info * rdi)392 int rvt_mcast_tree_empty(struct rvt_dev_info *rdi)
393 {
394 int i;
395 int in_use = 0;
396
397 for (i = 0; i < rdi->dparms.nports; i++)
398 if (rdi->ports[i]->mcast_tree.rb_node)
399 in_use++;
400 return in_use;
401 }
402