RDMA Write GPU Direct Example
This example extends the minimal RDMA Write program by moving both the source and destination buffers onto the GPU using HIP (ROCm), demonstrating a GPU-direct RDMA Write.
On the server side, the program:
- Accepts an RDMA connection and creates a Queue Pair.
- Allocates a 4 KB buffer in GPU memory using hipMalloc, initializes it with hipMemset, and registers this device pointer as an RDMA memory region (ibv_reg_mr with remote-write access).
- Exposes the GPU buffer’s (addr, rkey, len) to the client through RDMA connection private data.
- After the client performs an RDMA Write, the server copies the first 64 bytes from GPU to host using hipMemcpyDeviceToHost and prints the received string.
On the client side, the program:
- Resolves the server address and route, creates a Queue Pair, and connects via rdma_cm.
- Receives the server’s GPU memory metadata (addr, rkey, len) via private data.
- Allocates a GPU buffer with hipMalloc, copies the message "Hello RDMA from GPU" from host to device using hipMemcpyHostToDevice, and registers this GPU buffer as an RDMA memory region.
- Issues a single RDMA Write from its GPU buffer directly into the server’s GPU buffer and waits for completion by polling the send completion queue.
The two implementations follow: the client program first, then the server program.
#include <rdma/rdma_cma.h>
#include <infiniband/verbs.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <arpa/inet.h>
#include <netdb.h>
// HIP (ROCm)
#include <hip/hip_runtime.h>
/*
 * Descriptor of the peer's registered buffer, exchanged as RDMA CM private
 * data. Packed so its in-memory layout is exactly 8 + 4 + 4 bytes with no
 * padding, since it is copied byte-for-byte out of the connection event.
 */
struct Info {
    uint64_t addr; /* address of the remote buffer; used as the RDMA Write target */
    uint32_t rkey; /* remote key of the peer's memory region */
    uint32_t len;  /* length of the remote buffer in bytes, as advertised by the peer */
} __attribute__((packed));
/*
 * Block until the next CM event arrives on `ec` and verify that it is the
 * `expected` type.
 *
 * ec       - event channel to read from.
 * expected - event type that must arrive next.
 * what     - label used in diagnostics (passed to perror() when
 *            rdma_get_cm_event() itself fails).
 * out      - on success receives the event, still un-acknowledged; the caller
 *            must hand it to rdma_ack_cm_event() once done with its fields.
 *
 * Returns 0 on success. On failure prints a diagnostic, acknowledges any
 * unexpected event it received, and returns -1.
 */
static int client_wait_cm_event(struct rdma_event_channel *ec,
                                enum rdma_cm_event_type expected,
                                const char *what,
                                struct rdma_cm_event **out)
{
    struct rdma_cm_event *ev = NULL;

    if (rdma_get_cm_event(ec, &ev)) {
        perror(what);
        return -1;
    }
    if (ev->event != expected) {
        fprintf(stderr, "%s: expected %s, got %s (status %d)\n", what,
                rdma_event_str(expected), rdma_event_str(ev->event), ev->status);
        rdma_ack_cm_event(ev);
        return -1;
    }
    *out = ev;
    return 0;
}

/*
 * RDMA Write client.
 *
 * Usage: <prog> <server_ip> <port>
 *
 * Connects to the server through rdma_cm, reads the server's buffer
 * descriptor (struct Info) from the connection private data, stages the
 * message in a HIP device buffer, registers that buffer as a memory region
 * and issues one signaled RDMA Write into the server's buffer.
 *
 * Returns 0 when the write completed successfully, 1 on any failure. All
 * resources acquired so far are released on every exit path.
 */
int main(int argc, char **argv)
{
    /*
     * Every local is declared before the first `goto cleanup` so the jump
     * never bypasses an initialization (this also keeps the function valid
     * if the file is built as C++).
     */
    struct rdma_event_channel *ec = NULL;
    struct rdma_cm_id *id = NULL;
    struct rdma_cm_event *e = NULL;
    struct addrinfo *res = NULL;
    struct ibv_mr *mr = NULL;
    char *d_buf = NULL;
    struct ibv_qp_init_attr qa = {0};
    struct rdma_conn_param p = {0};
    struct ibv_send_wr wr = {0}, *bad = NULL;
    struct ibv_sge s = {0};
    struct ibv_wc wc;
    struct Info info;
    const char *msg = "Hello RDMA from GPU";
    size_t n = strlen(msg) + 1; /* include the terminating NUL */
    const char *ip = NULL;
    hipError_t herr;
    char ps[16];
    char *end = NULL;
    long port = 0;
    int gai = 0;
    int ret = 0;
    int polled = 0;
    int have_qp = 0;   /* set once rdma_create_qp() succeeded */
    int connected = 0; /* set once the connection is established */
    int rc = 1;

    if (argc < 3) {
        fprintf(stderr, "Usage: %s <server_ip> <port>\n", argv[0]);
        return 1;
    }
    ip = argv[1];

    /* strtol + range check rejects empty, non-numeric and out-of-range ports. */
    port = strtol(argv[2], &end, 10);
    if (end == argv[2] || *end != '\0' || port < 1 || port > 65535) {
        fprintf(stderr, "Invalid port: %s\n", argv[2]);
        return 1;
    }
    snprintf(ps, sizeof(ps), "%ld", port);

    ec = rdma_create_event_channel();
    if (!ec) {
        perror("create_event_channel");
        goto cleanup;
    }
    if (rdma_create_id(ec, &id, NULL, RDMA_PS_TCP)) {
        perror("create_id");
        id = NULL;
        goto cleanup;
    }

    /* getaddrinfo reports failures through its return value, not errno. */
    gai = getaddrinfo(ip, ps, NULL, &res);
    if (gai != 0) {
        fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(gai));
        res = NULL;
        goto cleanup;
    }

    if (rdma_resolve_addr(id, NULL, res->ai_addr, 2000)) {
        perror("resolve_addr");
        goto cleanup;
    }
    if (client_wait_cm_event(ec, RDMA_CM_EVENT_ADDR_RESOLVED, "event1", &e))
        goto cleanup;
    rdma_ack_cm_event(e);

    if (rdma_resolve_route(id, 2000)) {
        perror("resolve_route");
        goto cleanup;
    }
    if (client_wait_cm_event(ec, RDMA_CM_EVENT_ROUTE_RESOLVED, "event2", &e))
        goto cleanup;
    rdma_ack_cm_event(e);

    qa.qp_type = IBV_QPT_RC;
    qa.cap.max_send_wr = qa.cap.max_recv_wr = 8;
    qa.cap.max_send_sge = qa.cap.max_recv_sge = 1;
    qa.sq_sig_all = 1; /* every send WR produces a completion to poll for */
    if (rdma_create_qp(id, id->pd, &qa)) {
        perror("create_qp");
        goto cleanup;
    }
    have_qp = 1;

    if (rdma_connect(id, &p)) {
        perror("connect");
        goto cleanup;
    }
    if (client_wait_cm_event(ec, RDMA_CM_EVENT_ESTABLISHED, "event3", &e))
        goto cleanup;
    connected = 1;

    /* The descriptor must actually be present before it is copied out. */
    if (e->param.conn.private_data == NULL ||
        e->param.conn.private_data_len < sizeof(info)) {
        fprintf(stderr, "server sent %u bytes of private data, need %zu\n",
                (unsigned)e->param.conn.private_data_len, sizeof(info));
        rdma_ack_cm_event(e);
        goto cleanup;
    }
    memcpy(&info, e->param.conn.private_data, sizeof(info));
    rdma_ack_cm_event(e);

    /* Never write past the end of the buffer the server advertised. */
    if (n > info.len) {
        fprintf(stderr, "message (%zu bytes) exceeds remote buffer (%u bytes)\n",
                n, (unsigned)info.len);
        goto cleanup;
    }

    herr = hipMalloc((void **)&d_buf, n);
    if (herr != hipSuccess) {
        fprintf(stderr, "hipMalloc failed: %s\n", hipGetErrorString(herr));
        d_buf = NULL;
        goto cleanup;
    }
    herr = hipMemcpy(d_buf, msg, n, hipMemcpyHostToDevice);
    if (herr != hipSuccess) {
        fprintf(stderr, "hipMemcpy H2D failed: %s\n", hipGetErrorString(herr));
        goto cleanup;
    }

    /* Register the device pointer itself as the local source of the write. */
    mr = ibv_reg_mr(id->pd, d_buf, n, IBV_ACCESS_LOCAL_WRITE);
    if (!mr) {
        perror("ibv_reg_mr GPU");
        goto cleanup;
    }

    s.addr = (uintptr_t)d_buf;
    s.length = (uint32_t)n;
    s.lkey = mr->lkey;

    wr.opcode = IBV_WR_RDMA_WRITE;
    wr.sg_list = &s;
    wr.num_sge = 1;
    wr.wr.rdma.remote_addr = info.addr;
    wr.wr.rdma.rkey = info.rkey;

    /* ibv_post_send returns the error code instead of setting errno. */
    ret = ibv_post_send(id->qp, &wr, &bad);
    if (ret != 0) {
        fprintf(stderr, "ibv_post_send: %s\n", strerror(ret));
        goto cleanup;
    }

    /* Busy-wait for the completion; a negative return is a polling error. */
    do {
        polled = ibv_poll_cq(id->qp->send_cq, 1, &wc);
    } while (polled == 0);
    if (polled < 0) {
        fprintf(stderr, "ibv_poll_cq failed\n");
        goto cleanup;
    }
    if (wc.status != IBV_WC_SUCCESS) {
        fprintf(stderr, "RDMA WRITE failed, wc.status=%d\n", wc.status);
        goto cleanup;
    }

    printf("[client][GPU] RDMA write done (%zu bytes)\n", n);
    rc = 0;

cleanup:
    /* Release only what was acquired, in the same order as the success path. */
    if (connected)
        rdma_disconnect(id);
    if (mr)
        ibv_dereg_mr(mr);
    if (d_buf)
        hipFree(d_buf);
    if (have_qp)
        rdma_destroy_qp(id);
    if (id)
        rdma_destroy_id(id);
    if (ec)
        rdma_destroy_event_channel(ec);
    if (res)
        freeaddrinfo(res);
    return rc;
}
#include <arpa/inet.h>
#include <infiniband/verbs.h>
#include <rdma/rdma_cma.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <hip/hip_runtime.h>
/*
 * Buffer descriptor handed to the client as RDMA CM private data.
 * The packed attribute removes any padding, so the struct is exactly
 * 16 bytes and can be sent as-is.
 */
struct __attribute__((packed)) Info {
    uint64_t addr; /* address of the registered buffer */
    uint32_t rkey; /* remote key of the memory region covering the buffer */
    uint32_t len;  /* size of the buffer in bytes */
};
/*
 * Fatal-error helper: prints `m` followed by the message for the current
 * errno via perror(), then terminates the process with exit status 1.
 * Does not return.
 */
static void die(const char *m)
{
    perror(m);
    exit(1);
}
int main(int argc, char **argv) {
if (argc < 2) {
fprintf(stderr, "Usage: %s <port>\n", argv[0]);
return 1;
}
int port = atoi(argv[1]);
struct rdma_event_channel *ec = rdma_create_event_channel();
struct rdma_cm_id *lid, *id;
struct rdma_cm_event *e;
struct sockaddr_in a = {0};
a.sin_family = AF_INET;
a.sin_port = htons(port);
if (rdma_create_id(ec, &lid, NULL, RDMA_PS_TCP))
die("create_id");
if (rdma_bind_addr(lid, (struct sockaddr *)&a))
die("bind");
if (rdma_listen(lid, 1))
die("listen");
printf("[server] listening on %d ...\n", port);
if (rdma_get_cm_event(ec, &e))
die("get_event");
id = e->id;
rdma_ack_cm_event(e);
struct ibv_qp_init_attr qa = {0};
qa.qp_type = IBV_QPT_RC;
qa.cap.max_send_wr = qa.cap.max_recv_wr = 8;
qa.cap.max_send_sge = qa.cap.max_recv_sge = 1;
qa.sq_sig_all = 1;
if (rdma_create_qp(id, id->pd, &qa))
die("create_qp");
size_t len = 4096;
char *d_buf = NULL;
hipError_t herr;
herr = hipMalloc((void**)&d_buf, len);
if (herr != hipSuccess) {
fprintf(stderr, "hipMalloc failed: %s\n", hipGetErrorString(herr));
return 1;
}
herr = hipMemset(d_buf, 0, len);
if (herr != hipSuccess) {
fprintf(stderr, "hipMemset failed: %s\n", hipGetErrorString(herr));
return 1;
}
struct ibv_mr *mr = ibv_reg_mr(
id->pd, d_buf, len, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
if (!mr)
die("reg_mr(GPU)");
struct Info info = {(uint64_t)d_buf, mr->rkey, (uint32_t)len};
struct rdma_conn_param p = {0};
p.private_data = &info;
p.private_data_len = sizeof(info);
if (rdma_accept(id, &p))
die("accept");
if (rdma_get_cm_event(ec, &e))
die("event2");
rdma_ack_cm_event(e);
sleep(2);
char host_buf[64] = {0};
herr = hipMemcpy(host_buf, d_buf, sizeof(host_buf), hipMemcpyDeviceToHost);
if (herr != hipSuccess) {
fprintf(stderr, "hipMemcpy D2H failed: %s\n", hipGetErrorString(herr));
return 1;
}
printf("[server] got: '%.*s'\n", 64, host_buf);
rdma_disconnect(id);
ibv_dereg_mr(mr);
hipFree(d_buf);
rdma_destroy_qp(id);
rdma_destroy_id(id);
rdma_destroy_id(lid);
rdma_destroy_event_channel(ec);
return 0;
}