/*
 * Three-way comparator over two keys of `rsize` bytes each.
 *
 * The _compar_radix_* implementations return -1, 0 or 1 when the key at
 * r1 orders before, equal to, or after the key at r2; callers should only
 * rely on the sign of the result.
 */
typedef int (*_compar_fn_t)(const void * r1, const void * r2, size_t rsize);
/*
 * Bisection callback over keys of `rsize` bytes each: writes into `r` a key
 * lying between the keys at r1 and r2. The fixed-width _bisect_radix_*
 * helpers compute the midpoint r1 + ((r2 - r1) >> 1).
 * NOTE(review): assumes the key at r1 orders at or before the key at r2;
 * confirm against callers.
 */
typedef void (*_bisect_fn_t)(void * r, const void * r1, const void * r2, size_t rsize);
31 #define DEFTYPE(type) \
32 static int _compar_radix_ ## type ( \
36 return (signed) (*u1 > *u2) - (signed) (*u1 < *u2); \
38 static void _bisect_radix_ ## type ( \
43 *u = *u1 + ((*u2 - *u1) >> 1); \
48 static
int _compar_radix(const
void * r1, const
void * r2,
size_t rsize,
int dir) {
51 const unsigned char * u1 = (
const unsigned char *) r1;
52 const unsigned char * u2 = (
const unsigned char *) r2;
57 for(i = 0; i < rsize; i ++) {
58 if(*u1 < *u2)
return -1;
59 if(*u1 > *u2)
return 1;
68 const uint64_t * u1 = (
const uint64_t *) r1;
69 const uint64_t * u2 = (
const uint64_t *) r2;
71 u1 = (
const uint64_t *) ((
const char*) u1 + rsize - 8);
72 u2 = (
const uint64_t *) ((
const char*) u2 + rsize - 8);
74 for(i = 0; i < rsize; i += 8) {
75 if(*u1 < *u2)
return -1;
76 if(*u1 > *u2)
return 1;
94 static void _bisect_radix(
void * r,
const void * r1,
const void * r2,
size_t rsize,
int dir) {
96 const unsigned char * u1 = (
const unsigned char *) r1;
97 const unsigned char * u2 = (
const unsigned char *) r2;
98 unsigned char * u = (
unsigned char *) r;
99 unsigned int carry = 0;
105 for(i = 0; i < rsize; i ++) {
106 unsigned int tmp = (
unsigned int) *u2 + *u1 + carry;
107 if(tmp >= 256) carry = 1;
109 *u = tmp % (UINT8_MAX+1);
115 for(i = 0; i < rsize; i ++) {
116 unsigned int tmp = *u + carry * 256;
134 void (*radix)(
const void * ptr,
void * radix,
void * arg),
142 } be_detect = {0x01020304};
163 if(be_detect.c[0] != 1) {
220 if (
nmemb == 0)
return -1;
222 char tmpradix[d->
rsize];
224 ptrdiff_t right =
nmemb - 1;
237 while(right > left + 1) {
238 ptrdiff_t mid = ((right - left + 1) >> 1) + left;
261 if (
nmemb == 0)
return -1;
263 char tmpradix[d->
rsize];
265 ptrdiff_t right =
nmemb - 1;
278 while(right > left + 1) {
279 ptrdiff_t mid = ((right - left + 1) >> 1) + left;
305 static void _histogram(
char *
P,
int Plength,
void * mybase,
size_t mynmemb,
306 ptrdiff_t * myCLT, ptrdiff_t * myCLE,
312 for(it = 0; it < Plength; it ++) {
314 ptrdiff_t offset = myCLT[it];
316 ((
char*) mybase) + offset * d->
size,
320 myCLT[it + 1] = mynmemb;
324 for(it = 0; it < Plength; it ++) {
326 ptrdiff_t offset = myCLE[it];
328 ((
char*) mybase) + offset * d->
size,
332 myCLE[it + 1] = mynmemb;
345 char * Pmin,
char * Pmax,
int Plength,
348 memset(pi->
stable, 0, Plength *
sizeof(
int));
350 memset(pi->
narrow, 0, Plength *
sizeof(
int));
353 memset(pi->
Pleft, 0, Plength * d->
rsize *
sizeof(
char));
355 memset(pi->
Pright, 0, Plength * d->
rsize *
sizeof(
char));
359 for(i = 0; i < pi->
Plength; i ++) {
381 for(i = 0; i < pi->
Plength; i ++) {
382 if(pi->
stable[i])
continue;
403 printf(
"bisect %d %u %u %u\n", i, *(
int*) &
P[i * d->
rsize],
414 for(i = 0; i < pi->
Plength; i ++) {
415 printf(
"P %d stable %d narrow %d\n",
419 for(i = 0; i < pi->
Plength; i ++) {
435 ptrdiff_t * C, ptrdiff_t * CLT, ptrdiff_t * CLE) {
439 for(i = 0; i < pi->
Plength + 1; i ++) {
440 printf(
"counts %d LT %ld C %ld LE %ld\n",
441 i, CLT[i], C[i], CLE[i]);
444 for(i = 0; i < pi->
Plength; i ++) {
445 if( CLT[i + 1] < C[i + 1] && C[i + 1] <= CLE[i + 1]) {
449 if(CLT[i + 1] >= C[i + 1]) {
495 void * myoutbase,
size_t myoutnmemb,
501 MPI_Comm_size(comm, &o->
NTask);
513 endrun(4,
"total number of items in the item does not match the input %ld != %ld\n", o->
outnmemb, o->
nmemb);
530 char * Pmax,
char * Pmin,
560 MPI_Comm_size(comm, &
NTask);
564 size_t current_size = 0;
565 size_t current_outsize = 0;
566 int current_color = 0;
568 for(i = 0; i <
NTask; i ++) {
569 current_size += sizes[i];
571 lastcolor = current_color;
577 if(current_size > glocalsize || current_outsize > glocalsize) {
588 *ncolor = lastcolor + 1;
593 _collect_sizes(
size_t localsize,
size_t * sizes,
size_t * myoffset, MPI_Comm comm)
598 MPI_Comm_size(comm, &
NTask);
606 MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, sizes, 1,
MPI_TYPE_PTRDIFF, comm);
611 (*myoffset) += sizes[i];
641 MPIU_GetLoc(
const void * base, MPI_Datatype type, MPI_Op op, MPI_Comm comm)
645 MPI_Type_get_extent(type, &lb, &elsize);
647 void * tmp = malloc(elsize);
649 MPI_Allreduce(base, tmp, 1, type, op, comm);
653 MPI_Comm_size(comm, &
NTask);
657 if (memcmp(base, tmp, elsize) == 0) {
664 MPI_Allreduce(&rank, &ret, 1, MPI_INT, MPI_MIN, comm);
674 MPI_Comm_size(comm, &
NTask);
699 MPI_Comm_rank(descr->
Group, &rank);
717 MPI_Comm_free(&descr->
Segment);
718 MPI_Comm_free(&descr->
Group);
719 MPI_Comm_free(&descr->
Leaders);
730 tmr->
time = MPI_Wtime();
732 tmr->
name[19] =
'\0';
763 if(0 == strncmp(tmr->
name,
"END", 20))
776 void (*radix)(
const void * ptr,
void * radix,
void * arg),
780 const int line,
const char * file)
784 size, radix, rsize, arg, comm, line, file);
791 MPIU_Scatter (MPI_Comm comm,
int root,
const void * sendbuffer,
void * recvbuffer,
int nrecv,
size_t elsize,
int * totalnsend);
793 MPIU_Gather (MPI_Comm comm,
int root,
const void * sendbuffer,
void * recvbuffer,
int nsend,
size_t elsize,
int * totalnrecv);
796 checksum(
void * base,
size_t nbytes, MPI_Comm comm)
799 char * ptr = (
char *) base;
801 for(i = 0; i < nbytes; i ++) {
804 MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_LONG, MPI_SUM, comm);
810 void * myoutbase,
size_t myoutnmemb,
812 void (*radix)(
const void * ptr,
void * radix,
void * arg),
820 if(MPI_SUCCESS != MPI_Type_match_size(MPI_TYPECLASS_INTEGER,
sizeof(ptrdiff_t), &
MPI_TYPE_PTRDIFF))
821 endrun(3,
"Ptrdiff size %ld not recognised\n",
sizeof(ptrdiff_t));
826 uint64_t sum1 =
checksum(mybase, elsize * mynmemb, comm);
830 MPI_Comm_size(comm, &
NTask);
833 if(elsize > 8 && elsize % 8 != 0) {
835 endrun(12,
"MPSort: element size is large (%d) but not aligned to 8 bytes. "
836 "This is known to frequently trigger MPI bugs. "
837 "Caller site: %s:%d\n",
841 if(rsize > 8 && rsize % 8 != 0) {
843 endrun(12,
"MPSort: radix size is large (%d) but not aligned to 8 bytes. "
844 "This is known to frequently trigger MPI bugs. "
845 "Caller site: %s:%d\n",
854 size_t avgsegsize =
NTask;
855 if (avgsegsize * elsize > 4 * 1024 * 1024) {
857 avgsegsize = 4 * 1024 * 1024 / elsize;
860 message(0,
"MPSort: gathering all data to a single rank for sorting due to MPSORT_REQUIRE_GATHER_SORT. "
861 "Total number of items is %ld. Caller site: %s:%d\n",
868 message(0,
"MPSort: disable gathering data into larger chunks due to MPSORT_DISABLE_GATHER_SORT. "
869 "Caller site: %s:%d\n",
879 void * mysegmentbase = NULL;
880 void * myoutsegmentbase = NULL;
881 size_t mysegmentnmemb;
882 size_t myoutsegmentnmemb;
886 MPI_Comm_size(seggrp->
Group, &groupsize);
887 MPI_Comm_rank(seggrp->
Group, &grouprank);
894 mysegmentbase =
mymalloc(
"segmentbase", mysegmentnmemb * elsize);
895 myoutsegmentbase =
mymalloc(
"outsegment", myoutsegmentnmemb * elsize);
899 mysegmentbase = mybase;
900 myoutsegmentbase = myoutbase;
934 if(mysegmentbase !=
mybase)
942 endrun(5,
"Data changed after sorting; checksum mismatch.\n");
947 MPIU_Gather (MPI_Comm
comm,
int root,
const void * sendbuffer,
void * recvbuffer,
int nsend,
size_t elsize,
int * totalnrecv)
956 MPI_Type_contiguous(elsize, MPI_BYTE, &dtype);
957 MPI_Type_commit(&dtype);
962 MPI_Gather(&nsend, 1, MPI_INT, recvcount, 1, MPI_INT, root,
comm);
965 for(i = 1; i <=
NTask; i ++) {
966 rdispls[i] = rdispls[i - 1] + recvcount[i - 1];
971 *totalnrecv = rdispls[
NTask];
977 MPI_Gatherv(sendbuffer, nsend, dtype, recvbuffer, recvcount, rdispls, dtype, root,
comm);
981 MPI_Type_free(&dtype);
985 MPIU_Scatter (MPI_Comm
comm,
int root,
const void * sendbuffer,
void * recvbuffer,
int nrecv,
size_t elsize,
int * totalnsend)
993 MPI_Type_contiguous(elsize, MPI_BYTE, &dtype);
994 MPI_Type_commit(&dtype);
1000 MPI_Gather(&nrecv, 1, MPI_INT, sendcount, 1, MPI_INT, root,
comm);
1003 for(i = 1; i <=
NTask; i ++) {
1004 sdispls[i] = sdispls[i - 1] + sendcount[i - 1];
1009 *totalnsend = sdispls[
NTask];
1014 MPI_Scatterv(sendbuffer, sendcount, sdispls, dtype, recvbuffer, nrecv, dtype, root,
comm);
1018 MPI_Type_free(&dtype);
1024 ptrdiff_t * myC = (ptrdiff_t *)
mymalloc(
"myhistC", (o.
NTask + 1) *
sizeof(ptrdiff_t));
1027 ptrdiff_t * C = (ptrdiff_t *)
mymalloc(
"histC", (o.
NTask + 1) *
sizeof(ptrdiff_t));
1029 ptrdiff_t * myCLT = (ptrdiff_t *)
mymalloc(
"myhistC", (o.
NTask + 1) *
sizeof(ptrdiff_t));
1030 ptrdiff_t * CLT = (ptrdiff_t *)
mymalloc(
"histCLT", (o.
NTask + 1) *
sizeof(ptrdiff_t));
1032 ptrdiff_t * myCLE = (ptrdiff_t *)
mymalloc(
"myhistCLE", (o.
NTask + 1) *
sizeof(ptrdiff_t));
1033 ptrdiff_t * CLE = (ptrdiff_t *)
mymalloc(
"CLE", (o.
NTask + 1) *
sizeof(ptrdiff_t));
1045 MPI_Barrier(o.
comm);
1069 MPI_Allreduce(myCLT, CLT, o.
NTask + 1,
1071 MPI_Allreduce(myCLE, CLE, o.
NTask + 1,
1075 snprintf(bisectnum, 20,
"bisect%04d", iter);
1082 for(k = 0; k < o.
NTask; k ++) {
1083 MPI_Barrier(o.
comm);
1087 printf(
"P (%d): PMin %d PMax %d P ",
1092 for(i = 0; i < o.
NTask - 1; i ++) {
1093 printf(
" %d ", ((
int*)
P) [i]);
1098 for(i = 0; i < o.
NTask + 1; i ++) {
1099 printf(
"%ld ", C[i]);
1103 for(i = 0; i < o.
NTask + 1; i ++) {
1104 printf(
"%ld ", CLT[i]);
1108 for(i = 0; i < o.
NTask + 1; i ++) {
1109 printf(
"%ld ", CLE[i]);
1127 ptrdiff_t * myT_C = (ptrdiff_t *)
mymalloc(
"myhistT_C", (o.
NTask) *
sizeof(ptrdiff_t));
1128 ptrdiff_t * myT_CLT = (ptrdiff_t *)
mymalloc(
"myhistCLT", (o.
NTask) *
sizeof(ptrdiff_t));
1129 ptrdiff_t * myT_CLE = (ptrdiff_t *)
mymalloc(
"myhistCLE", (o.
NTask) *
sizeof(ptrdiff_t));
1158 for(i = 0;i < o.
NTask; i ++) {
1160 MPI_Barrier(o.
comm);
1162 for(j = 0; j < o.
NTask + 1; j ++) {
1163 printf(
"%d %d %d, ",
1188 for(i = 0; i < o.
NTask; i ++) {
1189 SendCount[i] = myC[i + 1] - myC[i];
1192 MPI_Alltoall(SendCount, 1, MPI_INT,
1193 RecvCount, 1, MPI_INT, o.
comm);
1197 size_t totrecv = RecvCount[0];
1198 for(i = 1; i < o.
NTask; i ++) {
1199 SendDispl[i] = SendDispl[i - 1] + SendCount[i - 1];
1200 RecvDispl[i] = RecvDispl[i - 1] + RecvCount[i - 1];
1201 if(SendDispl[i] != myC[i]) {
1202 endrun(7,
"SendDispl error\n");
1204 totrecv += RecvCount[i];
1212 for(k = 0; k < o.
NTask; k ++) {
1213 MPI_Barrier(o.
comm);
1218 for(i = 0; i < o.
NTask - 1; i ++) {
1219 printf(
"%d ", ((
int*)
P) [i]);
1224 for(i = 0; i < o.
NTask + 1; i ++) {
1225 printf(
"%d ", C[i]);
1229 for(i = 0; i < o.
NTask + 1; i ++) {
1230 printf(
"%d ", CLT[i]);
1234 for(i = 0; i < o.
NTask + 1; i ++) {
1235 printf(
"%d ", CLE[i]);
1240 for(i = 0; i < o.
NTask + 1; i ++) {
1241 printf(
"%d ", myC[i]);
1244 printf(
"MyCLT (%d): ", o.
ThisTask);
1245 for(i = 0; i < o.
NTask + 1; i ++) {
1246 printf(
"%d ", myCLT[i]);
1250 printf(
"MyCLE (%d): ", o.
ThisTask);
1251 for(i = 0; i < o.
NTask + 1; i ++) {
1252 printf(
"%d ", myCLE[i]);
1256 printf(
"Send Count(%d): ", o.
ThisTask);
1257 for(i = 0; i < o.
NTask; i ++) {
1258 printf(
"%d ", SendCount[i]);
1261 printf(
"My data(%d): ", o.
ThisTask);
1262 for(i = 0; i < mynmemb; i ++) {
1263 printf(
"%d ", ((
int*) mybase)[i]);
1290 MPI_Barrier(o.
comm);
1295 MPI_Barrier(o.
comm);
1306 char * Pmax,
char * Pmin,
1310 memset(Pmax, 0,
d->
rsize);
1311 memset(Pmin, -1,
d->
rsize);
1317 size_t * eachoutnmemb =
ta_malloc(
"eachoutnmemb",
size_t, o->
NTask);
1326 memset(myPmin, 0,
d->
rsize);
1327 memset(myPmax, 0,
d->
rsize);
1341 for(i = 0; i < o->
NTask; i ++) {
1342 C[i + 1] = C[i] + eachoutnmemb[i];
1343 if(eachnmemb[i] == 0)
continue;
1363 ptrdiff_t * myT_CLT,
1364 ptrdiff_t * myT_CLE,
1372 for(i = 0; i <
NTask; i ++) {
1373 myT_C[i] = myT_CLT[i];
1393 for(j = 0; j <
NTask; j ++) {
1394 ptrdiff_t recvcount = myT_C[j];
1398 ptrdiff_t deficit = C[
ThisTask + 1] - sure;
1400 for(j = 0; j <
NTask; j ++) {
1402 if(deficit == 0)
break;
1404 endrun(10,
"More items than there should be at j=%d: deficit=%ld\n (C: %ld sure %ld)", j, deficit, C[
ThisTask+1], sure);
1407 ptrdiff_t supply = myT_CLE[j] - myT_C[j];
1409 endrun(10,
"Less items than there should be at j=%d: supply=%ld (myTCLE %ld myTC %ld)\n", j, supply, myT_CLE[j], myT_C[j]);
1411 if(supply <= deficit) {
1415 myT_C[j] += deficit;
1426 static int _mpsort_env_parsed = 0;
1427 if(_mpsort_env_parsed)
return;
1429 _mpsort_env_parsed = 1;
1430 if(getenv(
"MPSORT_DISABLE_GATHER_SORT"))
1432 if(getenv(
/* The variable name must match exactly: the previous literal carried a
 * trailing space, so exporting MPSORT_REQUIRE_GATHER_SORT had no effect. */
"MPSORT_REQUIRE_GATHER_SORT"))
void message(int where, const char *fmt,...)
void endrun(int where, const char *fmt,...)
static int MPIU_GetLoc(const void *base, MPI_Datatype type, MPI_Op op, MPI_Comm comm)
void(* _bisect_fn_t)(void *r, const void *r1, const void *r2, size_t rsize)
static int _compar_radix_be_u8(const void *r1, const void *r2, size_t rsize)
static int _compar_radix_le_u8(const void *r1, const void *r2, size_t rsize)
int mpsort_mpi_has_options(int options)
static uint64_t checksum(void *base, size_t nbytes, MPI_Comm comm)
static void piter_accept(struct piter *pi, char *P, ptrdiff_t *C, ptrdiff_t *CLT, ptrdiff_t *CLE)
static size_t _collect_sizes(size_t localsize, size_t *sizes, size_t *myoffset, MPI_Comm comm)
static void _mpsort_mpi_parse_env()
static void MPIU_Gather(MPI_Comm comm, int root, const void *sendbuffer, void *recvbuffer, int nsend, size_t elsize, int *totalnrecv)
void mpsort_mpi_impl(void *mybase, size_t mynmemb, size_t size, void(*radix)(const void *ptr, void *radix, void *arg), size_t rsize, void *arg, MPI_Comm comm, const int line, const char *file)
static int _compar_radix_u8(const void *r1, const void *r2, size_t rsize, int dir)
static void radix_sort(void *base, size_t nmemb, size_t size, void(*radix)(const void *ptr, void *radix, void *arg), size_t rsize, void *arg)
static int _compar_radix_le(const void *r1, const void *r2, size_t rsize)
static ptrdiff_t _bsearch_last_le(void *P, void *base, size_t nmemb, struct crstruct *d)
static struct TIMERS _TIMERS
static void _create_segment_group(struct SegmentGroupDescr *descr, size_t *sizes, size_t avgsegsize, int Ngroup, MPI_Comm comm)
static void _find_Pmax_Pmin_C(void *mybase, size_t mynmemb, size_t myoutnmemb, char *Pmax, char *Pmin, ptrdiff_t *C, struct crstruct *d, struct crmpistruct *o)
static int _solve_for_layout_mpi(int NTask, ptrdiff_t *C, ptrdiff_t *myT_CLT, ptrdiff_t *myT_CLE, ptrdiff_t *myT_C, MPI_Comm comm)
static void _destroy_segment_group(struct SegmentGroupDescr *descr)
void mpsort_mpi_unset_options(int options)
static void _bisect_radix_le(void *r, const void *r1, const void *r2, size_t rsize)
void mpsort_free_timers(void)
static int _mpsort_mpi_options
static void piter_init(struct piter *pi, char *Pmin, char *Pmax, int Plength, struct crstruct *d)
static void _bisect_radix_be(void *r, const void *r1, const void *r2, size_t rsize)
void mpsort_mpi_set_options(int options)
static void piter_destroy(struct piter *pi)
static void _setup_mpsort_mpi(struct crmpistruct *o, struct crstruct *d, void *myoutbase, size_t myoutnmemb, MPI_Comm comm)
int mpsort_mpi_find_ntimers(struct TIMERS *timers)
static void _histogram(char *P, int Plength, void *mybase, size_t mynmemb, ptrdiff_t *myCLT, ptrdiff_t *myCLE, struct crstruct *d)
static int mpsort_mpi_histogram_sort(struct crstruct d, struct crmpistruct o)
void mpsort_setup_timers(int ntimers)
static void _destroy_mpsort_mpi(struct crmpistruct *o)
static void _bisect_radix(void *r, const void *r1, const void *r2, size_t rsize, int dir)
static struct crstruct _cacr_d
static void mpsort_increment_timer(const char *name, int erase)
static void MPIU_Scatter(MPI_Comm comm, int root, const void *sendbuffer, void *recvbuffer, int nrecv, size_t elsize, int *totalnsend)
static int _assign_colors(size_t glocalsize, size_t *sizes, int *ncolor, MPI_Comm comm)
static ptrdiff_t _bsearch_last_lt(void *P, void *base, size_t nmemb, struct crstruct *d)
static void piter_bisect(struct piter *pi, char *P)
void _setup_radix_sort(struct crstruct *d, void *base, size_t nmemb, size_t size, void(*radix)(const void *ptr, void *radix, void *arg), size_t rsize, void *arg)
static int _compute_and_compar_radix(const void *p1, const void *p2)
static MPI_Datatype MPI_TYPE_PTRDIFF
static int _compar_radix(const void *r1, const void *r2, size_t rsize, int dir)
static int piter_all_done(struct piter *pi)
void mpsort_mpi_report_last_run()
void mpsort_mpi_newarray_impl(void *mybase, size_t mynmemb, void *myoutbase, size_t myoutnmemb, size_t elsize, void(*radix)(const void *ptr, void *radix, void *arg), size_t rsize, void *arg, MPI_Comm comm, const int line, const char *file)
int(* _compar_fn_t)(const void *r1, const void *r2, size_t rsize)
static int _compar_radix_be(const void *r1, const void *r2, size_t rsize)
#define MPSORT_REQUIRE_GATHER_SORT
#define MPSORT_DISABLE_GATHER_SORT
#define mymalloc(name, size)
#define ta_malloc(name, type, nele)
#define mymalloc2(name, size)
MPI_Datatype MPI_TYPE_DATA
MPI_Datatype MPI_TYPE_RADIX
void(* radix)(const void *ptr, void *radix, void *arg)
int MPI_Alltoallv_smart(void *sendbuf, int *sendcnts, int *sdispls, MPI_Datatype sendtype, void *recvbuf, int *recvcnts, int *rdispls, MPI_Datatype recvtype, MPI_Comm comm)