17#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
18#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
26#ifndef VQSORT_SECURE_RNG
27#define VQSORT_SECURE_RNG 0
31#include "third_party/absl/random/random.h"
41#include <sanitizer/msan_interface.h>
47#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
48 defined(HWY_TARGET_TOGGLE)
49#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
50#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
52#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
75 __msan_unpoison(p, bytes);
84 size_t start = 0,
size_t max_lanes = 16) {
86 Print(
d, label,
v, start, max_lanes);
98template <
class Traits,
typename T>
101 constexpr size_t N1 = st.LanesPerKey();
104 while (start < num_lanes) {
105 const size_t left = 2 * start + N1;
106 const size_t right = 2 * start + 2 * N1;
107 if (left >= num_lanes)
break;
108 size_t idx_larger = start;
109 const auto key_j = st.SetKey(
d, lanes + start);
110 if (
AllTrue(
d, st.Compare(
d, key_j, st.SetKey(
d, lanes + left)))) {
113 if (right < num_lanes &&
114 AllTrue(
d, st.Compare(
d, st.SetKey(
d, lanes + idx_larger),
115 st.SetKey(
d, lanes + right)))) {
118 if (idx_larger == start)
break;
119 st.Swap(lanes + start, lanes + idx_larger);
126template <
class Traits,
typename T>
128 constexpr size_t N1 = st.LanesPerKey();
130 if (num_lanes < 2 * N1)
return;
133 for (
size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
137 for (
size_t i = num_lanes - N1; i != 0; i -= N1) {
139 st.Swap(lanes + 0, lanes + i);
146#if VQSORT_ENABLED || HWY_IDE
151template <
class D,
class Traits,
typename T>
156 using V =
decltype(
Zero(
d));
163 const size_t num_pow2 =
size_t{1}
165 static_cast<uint32_t
>(num - 1)));
177 const size_t N_sn =
Lanes(CappedTag<T, Constants::kMaxCols>());
179 SortingNetwork(st, keys, N_sn);
185 for (i = 0; i +
N <= num; i +=
N) {
192 const V kPadding = st.LastValue(
d);
199 SortingNetwork(st, buf, cols);
201 for (i = 0; i +
N <= num; i +=
N) {
212template <
class D,
class Traits,
class T>
213HWY_INLINE size_t PartitionToMultipleOfUnroll(D
d, Traits st,
224 const size_t num_rem =
225 (num < 2 * kUnroll *
N) ? num : (num & (kUnroll *
N - 1));
227 for (; i +
N <= num_rem; i +=
N) {
228 const Vec<D> vL =
LoadU(
d, keys + readL);
231 const auto comp = st.Compare(
d, pivot, vL);
237 const auto mask =
FirstN(
d, num_rem - i);
238 const Vec<D> vL =
LoadU(
d, keys + readL);
240 const auto comp = st.Compare(
d, pivot, vL);
255 memcpy(posL, keys + num, bufR *
sizeof(T));
256 memcpy(keys + num, buf, bufR *
sizeof(T));
257 return static_cast<size_t>(posL - keys);
261V OrXor(
const V o,
const V x1,
const V x2) {
263 return Or(o,
Xor(x1, x2));
268template <
class D,
class Traits,
typename T>
269HWY_INLINE void StoreLeftRight(D
d, Traits st,
const Vec<D>
v,
271 size_t& writeL,
size_t& remaining) {
274 const auto comp = st.Compare(
d, pivot,
v);
285 const auto lr = st.CompressKeys(
v, comp);
290 StoreU(lr,
d, keys + remaining + writeL);
302template <
class D,
class Traits,
typename T>
303HWY_INLINE void StoreLeftRight4(D
d, Traits st,
const Vec<D> v0,
304 const Vec<D> v1,
const Vec<D> v2,
305 const Vec<D> v3,
const Vec<D> pivot,
308 StoreLeftRight(
d, st, v0, pivot, keys, writeL, remaining);
309 StoreLeftRight(
d, st, v1, pivot, keys, writeL, remaining);
310 StoreLeftRight(
d, st, v2, pivot, keys, writeL, remaining);
311 StoreLeftRight(
d, st, v3, pivot, keys, writeL, remaining);
318template <
class D,
class Traits,
typename T>
321 using V =
decltype(
Zero(
d));
330 const V vlast =
LoadU(
d, keys + last);
332 const size_t consumedL =
333 PartitionToMultipleOfUnroll(
d, st, keys, num, pivot, buf);
371 size_t remaining = num;
381 const V vL0 =
LoadU(
d, readL + 0 *
N);
382 const V vL1 =
LoadU(
d, readL + 1 *
N);
383 const V vL2 =
LoadU(
d, readL + 2 *
N);
384 const V vL3 =
LoadU(
d, readL + 3 *
N);
385 readL += kUnroll *
N;
386 readR -= kUnroll *
N;
387 const V vR0 =
LoadU(
d, readR + 0 *
N);
388 const V vR1 =
LoadU(
d, readR + 1 *
N);
389 const V vR2 =
LoadU(
d, readR + 2 *
N);
390 const V vR3 =
LoadU(
d, readR + 3 *
N);
393 while (readL != readR) {
397 const size_t capacityL =
398 static_cast<size_t>((readL - keys) -
static_cast<ptrdiff_t
>(writeL));
423 if (kUnroll *
N < capacityL) {
424 readR -= kUnroll *
N;
435 readL += kUnroll *
N;
439 StoreLeftRight4(
d, st, v0, v1, v2, v3, pivot, keys, writeL, remaining);
443 StoreLeftRight4(
d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, remaining);
444 StoreLeftRight4(
d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, remaining);
453 const size_t totalR = last - writeL;
454 const size_t startR = totalR <
N ? writeL + totalR -
N : writeL;
458 const auto comp = st.Compare(
d, pivot, vlast);
462 return consumedL + writeL;
469template <
class D,
class Traits,
typename T>
471 size_t num,
const Vec<D> valueL,
472 const Vec<D> valueR, Vec<D>& third,
481 for (; i +
N <= num; i +=
N) {
482 const Vec<D>
v =
LoadU(
d, keys + i);
486 const Mask<D> eqL = st.EqualKeys(
d,
v, valueL);
487 const Mask<D> eqR = st.EqualKeys(
d,
v, valueR);
499 third = st.SetKey(
d, keys + i + lane);
501 fprintf(stderr,
"found 3rd value at vec %zu; writeL %zu\n", i, writeL);
504 for (; writeL +
N <= i; writeL +=
N) {
505 StoreU(valueR,
d, keys + writeL);
510 StoreU(valueL,
d, keys + writeL);
515 const size_t remaining = num - i;
517 const Vec<D>
v =
Load(
d, buf);
518 const Mask<D> valid =
FirstN(
d, remaining);
519 const Mask<D> eqL =
And(st.EqualKeys(
d,
v, valueL), valid);
520 const Mask<D> eqR = st.EqualKeys(
d,
v, valueR);
522 const Mask<D> eq =
Or(
Or(eqL, eqR),
Not(valid));
526 third = st.SetKey(
d, keys + i + lane);
528 fprintf(stderr,
"found 3rd value at partial vec %zu; writeL %zu\n", i,
532 for (; writeL +
N <= i; writeL +=
N) {
533 StoreU(valueR,
d, keys + writeL);
543 for (; i +
N <= num; i +=
N) {
549 fprintf(stderr,
"Successful MaybePartitionTwoValue\n");
555template <
class D,
class Traits,
typename T>
557 size_t num,
const Vec<D> valueL,
558 const Vec<D> valueR, Vec<D>& third,
563 size_t pos = num -
N;
569 for (; pos < num; pos -=
N) {
570 const Vec<D>
v =
LoadU(
d, keys + pos);
574 const Mask<D> eqL = st.EqualKeys(
d,
v, valueL);
575 const Mask<D> eqR = st.EqualKeys(
d,
v, valueR);
581 third = st.SetKey(
d, keys + pos + lane);
583 fprintf(stderr,
"found 3rd value at vec %zu; countR %zu\n", pos,
591 const size_t endL = num - countR;
592 for (; pos +
N <= endL; pos +=
N) {
603 const size_t remaining = pos +
N;
605 const Vec<D>
v =
LoadU(
d, keys);
606 const Mask<D> valid =
FirstN(
d, remaining);
607 const Mask<D> eqL = st.EqualKeys(
d,
v, valueL);
608 const Mask<D> eqR =
And(st.EqualKeys(
d,
v, valueR), valid);
610 const Mask<D> eq =
Or(
Or(eqL, eqR),
Not(valid));
614 third = st.SetKey(
d, keys + lane);
616 fprintf(stderr,
"found 3rd value at partial vec %zu; writeR %zu\n", pos,
624 const size_t endL = num - countR;
625 for (; pos +
N <= endL; pos +=
N) {
638 const size_t endL = num - countR;
640 for (; i +
N <= endL; i +=
N) {
648 "MaybePartitionTwoValueR countR %zu pos %zu i %zu endL %zu\n",
649 countR, pos, i, endL);
659template <
class D,
class Traits,
typename T>
660HWY_INLINE bool PartitionIfTwoKeys(D
d, Traits st,
const Vec<D> pivot,
662 const size_t idx_second,
const Vec<D> second,
665 const bool is_pivotR =
AllFalse(
d, st.Compare(
d, pivot, second));
667 fprintf(stderr,
"Samples all equal, diff at %zu, isPivotR %d\n", idx_second,
674 return is_pivotR ? MaybePartitionTwoValueR(
d, st, keys, num, second, pivot,
676 : MaybePartitionTwoValue(
d, st, keys + idx_second,
677 num - idx_second, pivot, second,
683template <
class D,
class Traits,
typename T>
686 constexpr size_t kSampleLanes = 3 * 64 /
sizeof(T);
687 constexpr size_t N1 = st.LanesPerKey();
688 const Vec<D> valueL = st.SetKey(
d, samples);
689 const Vec<D> valueR = st.SetKey(
d, samples + kSampleLanes - N1);
692 const Vec<D> prev = st.PrevValue(
d, valueR);
703 return MaybePartitionTwoValue(
d, st, keys, num, valueL, valueR, third, buf);
708template <
class Traits,
class V>
709HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) {
715 const auto sum =
Xor(
Xor(v0, v1), v2);
716 const auto first = st.First(
d, st.First(
d, v0, v1), v2);
717 const auto last = st.Last(
d, st.Last(
d, v0, v1), v2);
718 return Xor(
Xor(sum, first), last);
721 v1 = st.Last(
d, v0, v1);
722 v1 = st.First(
d, v1, v2);
727using Generator = absl::BitGen;
733 Generator(
const void* heap,
size_t num) {
738 explicit Generator(uint64_t seed) {
743 uint64_t operator()() {
744 const uint64_t b = b_;
746 const uint64_t next = a_ ^ w_;
747 a_ = (b + (b << 3)) ^ (b >> 11);
748 const uint64_t rot = (b << 24) | (b >> 40);
765HWY_INLINE size_t RandomChunkIndex(
const uint32_t num_chunks, uint32_t bits) {
766 const uint64_t chunk_index = (
static_cast<uint64_t
>(bits) * num_chunks) >> 32;
768 return static_cast<size_t>(chunk_index);
772template <
class D,
class Traits,
typename T>
775 using V =
decltype(
Zero(
d));
784 const size_t misalign =
785 (
reinterpret_cast<uintptr_t
>(keys) /
sizeof(T)) & (kLanesPerChunk - 1);
787 const size_t consume = kLanesPerChunk - misalign;
793 uint64_t* bits64 =
reinterpret_cast<uint64_t*
>(buf);
794 for (
size_t i = 0; i < 5; ++i) {
797 const uint32_t* bits =
reinterpret_cast<const uint32_t*
>(buf);
799 const size_t num_chunks64 = num / kLanesPerChunk;
801 const uint32_t num_chunks =
802 static_cast<uint32_t
>(
HWY_MIN(num_chunks64, 0xFFFFFFFFull));
804 const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) * kLanesPerChunk;
805 const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) * kLanesPerChunk;
806 const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) * kLanesPerChunk;
807 const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) * kLanesPerChunk;
808 const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) * kLanesPerChunk;
809 const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) * kLanesPerChunk;
810 const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) * kLanesPerChunk;
811 const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) * kLanesPerChunk;
812 const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) * kLanesPerChunk;
813 for (
size_t i = 0; i < kLanesPerChunk; i +=
N) {
814 const V v0 =
Load(
d, keys + offset0 + i);
815 const V v1 =
Load(
d, keys + offset1 + i);
816 const V v2 =
Load(
d, keys + offset2 + i);
817 const V medians0 = MedianOf3(st, v0, v1, v2);
818 Store(medians0,
d, buf + i);
820 const V v3 =
Load(
d, keys + offset3 + i);
821 const V v4 =
Load(
d, keys + offset4 + i);
822 const V v5 =
Load(
d, keys + offset5 + i);
823 const V medians1 = MedianOf3(st, v3, v4, v5);
824 Store(medians1,
d, buf + i + kLanesPerChunk);
826 const V v6 =
Load(
d, keys + offset6 + i);
827 const V v7 =
Load(
d, keys + offset7 + i);
828 const V v8 =
Load(
d, keys + offset8 + i);
829 const V medians2 = MedianOf3(st, v6, v7, v8);
830 Store(medians2,
d, buf + i + kLanesPerChunk * 2);
835template <
class D,
class Traits>
838 constexpr size_t kSampleLanes = 3 * 64 /
sizeof(TFromD<D>);
842 const V first = st.SetKey(
d, samples);
846 for (; i +
N <= kSampleLanes; i +=
N) {
847 const V
v =
Load(
d, samples + i);
848 diff = OrXor(diff, first,
v);
851 const V
v =
Load(
d, samples + i);
852 const auto valid =
FirstN(
d, kSampleLanes - i);
853 diff =
IfThenElse(valid, OrXor(diff, first,
v), diff);
855 return st.NoKeyDifference(
d, diff);
858template <
class D,
class Traits,
typename T>
861 constexpr size_t kSampleLanes = 3 * 64 /
sizeof(T);
863 const size_t N128 =
Lanes(d128);
866 static_assert(192 <= kBytes,
"");
868 const auto kPadding = st.LastValue(d128);
871 for (
size_t i = kSampleLanes; i <= kBytes /
sizeof(T); i += N128) {
872 StoreU(kPadding, d128, buf + i);
875 SortingNetwork(st, buf, kCols);
879 fprintf(stderr,
"Samples:\n");
880 for (
size_t i = 0; i < kSampleLanes; i +=
N) {
888enum class PivotResult {
895HWY_INLINE const char* PivotResultString(PivotResult result) {
897 case PivotResult::kDone:
899 case PivotResult::kNormal:
901 case PivotResult::kIsFirst:
903 case PivotResult::kWasLast:
909template <
class Traits,
typename T>
911 constexpr size_t kSampleLanes = 3 * 64 /
sizeof(T);
912 constexpr size_t N1 = st.LanesPerKey();
914 constexpr size_t kRankMid = kSampleLanes / 2;
915 static_assert(kRankMid % N1 == 0,
"Mid is not an aligned key");
918 size_t rank_prev = kRankMid - N1;
919 for (; st.Equal1(samples + rank_prev, samples + kRankMid); rank_prev -= N1) {
921 if (rank_prev == 0)
return 0;
924 size_t rank_next = rank_prev + N1;
925 for (; st.Equal1(samples + rank_next, samples + kRankMid); rank_next += N1) {
928 if (rank_next == kSampleLanes - N1)
return rank_prev;
936 const size_t excess_if_median = rank_next - kRankMid;
937 const size_t excess_if_prev = kRankMid - rank_prev;
938 return excess_if_median < excess_if_prev ? kRankMid : rank_prev;
943template <
class D,
class Traits,
typename T>
946 const size_t pivot_rank = PivotRank(st, samples);
947 const Vec<D> pivot = st.SetKey(
d, samples + pivot_rank);
949 fprintf(stderr,
" Pivot rank %zu = %f\n", pivot_rank,
950 static_cast<double>(
GetLane(pivot)));
953 constexpr size_t kSampleLanes = 3 * 64 /
sizeof(T);
954 constexpr size_t N1 = st.LanesPerKey();
955 const Vec<D> last = st.SetKey(
d, samples + kSampleLanes - N1);
956 const bool all_neq =
AllTrue(
d, st.NotEqualKeys(
d, pivot, last));
964template <
class D,
class Traits,
typename T>
965HWY_INLINE bool AllEqual(D
d, Traits st,
const Vec<D> pivot,
971 const Vec<D> zero =
Zero(
d);
974 const size_t misalign =
975 (
reinterpret_cast<uintptr_t
>(keys) /
sizeof(T)) & (
N - 1);
977 const size_t consume =
N - misalign;
979 const Vec<D>
v =
LoadU(
d, keys);
981 const Mask<D> diff =
And(
FirstN(
d, consume), st.NotEqualKeys(
d,
v, pivot));
984 *first_mismatch = lane;
989 HWY_DASSERT(((
reinterpret_cast<uintptr_t
>(keys + i) /
sizeof(T)) & (
N - 1)) ==
1001 constexpr size_t kLoops = 8;
1002 const size_t lanes_per_group = kLoops * 2 *
N;
1004 for (; i + lanes_per_group <= num; i += lanes_per_group) {
1006 for (
size_t loop = 0; loop < kLoops; ++loop) {
1007 const Vec<D> v0 =
Load(
d, keys + i + loop * 2 *
N);
1008 const Vec<D> v1 =
Load(
d, keys + i + loop * 2 *
N +
N);
1009 diff0 = OrXor(diff0, v0, pivot);
1010 diff1 = OrXor(diff1, v1, pivot);
1017 const Vec<D>
v =
Load(
d, keys + i);
1018 const Mask<D> diff = st.NotEqualKeys(
d,
v, pivot);
1021 *first_mismatch = i + lane;
1029 for (; i +
N <= num; i +=
N) {
1030 const Vec<D>
v =
Load(
d, keys + i);
1031 const Mask<D> diff = st.NotEqualKeys(
d,
v, pivot);
1034 *first_mismatch = i + lane;
1040 const Vec<D>
v =
LoadU(
d, keys + i);
1041 const Mask<D> diff = st.NotEqualKeys(
d,
v, pivot);
1044 *first_mismatch = i + lane;
1049 fprintf(stderr,
"All keys equal\n");
1055template <
class D,
class Traits,
typename T>
1057 size_t num,
const Vec<D> pivot) {
1062 fprintf(stderr,
"Scanning for before\n");
1067 constexpr size_t kLoops = 16;
1068 const size_t lanes_per_group = kLoops *
N;
1070 Vec<D> first = pivot;
1073 for (; i + lanes_per_group <= num; i += lanes_per_group) {
1075 for (
size_t loop = 0; loop < kLoops; ++loop) {
1076 const Vec<D> curr =
LoadU(
d, keys + i + loop *
N);
1077 first = st.First(
d, first, curr);
1082 fprintf(stderr,
"Stopped scanning at end of group %zu\n",
1083 i + lanes_per_group);
1089 for (; i +
N <= num; i +=
N) {
1090 const Vec<D> curr =
LoadU(
d, keys + i);
1093 fprintf(stderr,
"Stopped scanning at %zu\n", i);
1100 const Vec<D> curr =
LoadU(
d, keys + num -
N);
1103 fprintf(stderr,
"Stopped scanning at last %zu\n", num -
N);
1113template <
class D,
class Traits,
typename T>
1115 size_t num,
const Vec<D> pivot) {
1120 fprintf(stderr,
"Scanning for after\n");
1125 constexpr size_t kLoops = 16;
1126 const size_t lanes_per_group = kLoops *
N;
1128 Vec<D> last = pivot;
1131 for (; i + lanes_per_group <= num; i += lanes_per_group) {
1133 for (
size_t loop = 0; loop < kLoops; ++loop) {
1134 const Vec<D> curr =
LoadU(
d, keys + i + loop *
N);
1135 last = st.Last(
d, last, curr);
1140 fprintf(stderr,
"Stopped scanning at end of group %zu\n",
1141 i + lanes_per_group);
1147 for (; i +
N <= num; i +=
N) {
1148 const Vec<D> curr =
LoadU(
d, keys + i);
1151 fprintf(stderr,
"Stopped scanning at %zu\n", i);
1158 const Vec<D> curr =
LoadU(
d, keys + num -
N);
1161 fprintf(stderr,
"Stopped scanning at last %zu\n", num -
N);
1172template <
class D,
class Traits,
typename T>
1173HWY_INLINE Vec<D> ChoosePivotForEqualSamples(D
d, Traits st,
1176 Vec<D> second, Vec<D> third,
1177 PivotResult& result) {
1178 const Vec<D> pivot = st.SetKey(
d, samples);
1182 result = PivotResult::kIsFirst;
1186 result = PivotResult::kWasLast;
1187 return st.PrevValue(
d, pivot);
1194 const bool before = !
AllFalse(
d, st.Compare(
d, second, pivot));
1197 if (
HWY_UNLIKELY(ExistsAnyAfter(
d, st, keys, num, pivot))) {
1198 result = PivotResult::kNormal;
1206 result = PivotResult::kWasLast;
1207 return st.PrevValue(
d, pivot);
1211 if (
HWY_UNLIKELY(ExistsAnyBefore(
d, st, keys, num, pivot))) {
1212 result = PivotResult::kNormal;
1219 st.Sort2(
d, second, third);
1221 const bool before = !
AllFalse(
d, st.Compare(
d, second, pivot));
1222 const bool after = !
AllFalse(
d, st.Compare(
d, pivot, third));
1229 if (
HWY_UNLIKELY(after || ExistsAnyAfter(
d, st, keys, num, pivot))) {
1230 result = PivotResult::kNormal;
1238 result = PivotResult::kWasLast;
1239 return st.PrevValue(
d, pivot);
1243 if (
HWY_UNLIKELY(ExistsAnyBefore(
d, st, keys, num, pivot))) {
1244 result = PivotResult::kNormal;
1254 result = PivotResult::kIsFirst;
1260template <
class D,
class Traits,
typename T>
1265 if (num <
N)
return;
1267 Vec<D> first = st.LastValue(
d);
1268 Vec<D> last = st.FirstValue(
d);
1271 for (; i +
N <= num; i +=
N) {
1272 const Vec<D>
v =
LoadU(
d, keys + i);
1273 first = st.First(
d,
v, first);
1274 last = st.Last(
d,
v, last);
1278 const Vec<D>
v =
LoadU(
d, keys + num -
N);
1279 first = st.First(
d,
v, first);
1280 last = st.Last(
d,
v, last);
1283 first = st.FirstOfLanes(
d, first, buf);
1284 last = st.LastOfLanes(
d, last, buf);
1292template <
class D,
class Traits,
typename T>
1296 const size_t remaining_levels) {
1300 BaseCase(
d, st, keys, keys_end, num, buf);
1306 fprintf(stderr,
"\n\n=== Recurse depth=%zu len=%zu\n", remaining_levels,
1308 PrintMinMax(
d, st, keys, num, buf);
1311 DrawSamples(
d, st, keys, num, buf, rng);
1314 PivotResult result = PivotResult::kNormal;
1316 pivot = st.SetKey(
d, buf);
1317 size_t idx_second = 0;
1318 if (
HWY_UNLIKELY(AllEqual(
d, st, pivot, keys, num, &idx_second))) {
1323 const Vec<D> second = st.SetKey(
d, keys + idx_second);
1331 PartitionIfTwoKeys(
d, st, pivot, keys, num, idx_second,
1332 second, third, buf))) {
1338 pivot = ChoosePivotForEqualSamples(
d, st, keys, num, buf, second, third,
1344 SortSamples(
d, st, buf);
1349 PartitionIfTwoSamples(
d, st, keys, num, buf))) {
1353 pivot = ChoosePivotByRank(
d, st, buf);
1360 fprintf(stderr,
"HeapSort reached, size=%zu\n", num);
1366 const size_t bound = Partition(
d, st, keys, num, pivot, buf);
1368 fprintf(stderr,
"bound %zu num %zu result %s\n", bound, num,
1369 PivotResultString(result));
1379 if (
HWY_LIKELY(result != PivotResult::kIsFirst)) {
1380 Recurse(
d, st, keys, keys_end, bound, buf, rng, remaining_levels - 1);
1382 if (
HWY_LIKELY(result != PivotResult::kWasLast)) {
1383 Recurse(
d, st, keys + bound, keys_end, num - bound, buf, rng,
1384 remaining_levels - 1);
1389template <
class D,
class Traits,
typename T>
1398 const bool partial_128 = !
IsFull(
d) &&
N < 2 && st.Is128();
1403 constexpr bool kPotentiallyHuge =
1405 const bool huge_vec = kPotentiallyHuge && (2 *
N > base_case_num);
1406 if (partial_128 || huge_vec) {
1408 fprintf(stderr,
"WARNING: using slow HeapSort: partial %d huge %d\n",
1409 partial_128, huge_vec);
1437template <
class D,
class Traits,
typename T>
1441 fprintf(stderr,
"=============== Sort num %zu\n", num);
1444#if VQSORT_ENABLED || HWY_IDE
1445#if !HWY_HAVE_SCALABLE
1451 static_assert(
sizeof(storage) <= 8192,
"Unexpectedly large, check size");
1455 if (detail::HandleSpecialCases(
d, st, keys, num))
return;
1457#if HWY_MAX_BYTES > 64
1460 return Sort(
CappedTag<T, 64 /
sizeof(T)>(), st, keys, num, buf);
1464 detail::Generator rng(keys, num);
1468 detail::Recurse(
d, st, keys, keys + num, num, buf, rng, max_levels);
1473 fprintf(stderr,
"WARNING: using slow HeapSort because vqsort disabled\n");
#define HWY_MAX(a, b)
Definition: base.h:135
#define HWY_RESTRICT
Definition: base.h:64
#define HWY_NOINLINE
Definition: base.h:72
#define HWY_MIN(a, b)
Definition: base.h:134
#define HWY_INLINE
Definition: base.h:70
#define HWY_DASSERT(condition)
Definition: base.h:238
#define HWY_DEFAULT_UNROLL
Definition: base.h:146
#define HWY_LIKELY(expr)
Definition: base.h:75
#define HWY_UNLIKELY(expr)
Definition: base.h:76
static void Fill24Bytes(const void *seed_heap, size_t seed_num, void *bytes)
void SiftDown(Traits st, T *HWY_RESTRICT lanes, const size_t num_lanes, size_t start)
Definition: vqsort-inl.h:99
HWY_INLINE void MaybePrintVector(D d, const char *label, Vec< D > v, size_t start=0, size_t max_lanes=16)
Definition: vqsort-inl.h:83
HWY_INLINE Mask128< T, N > ExclusiveNeither(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:963
HWY_INLINE bool AllTrue(hwy::SizeTag< 1 >, const Mask128< T > m)
Definition: wasm_128-inl.h:3661
HWY_INLINE Mask128< T, N > And(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:815
HWY_INLINE void UnpoisonIfMemorySanitizer(void *p, size_t bytes)
Definition: vqsort-inl.h:73
void HeapSort(Traits st, T *HWY_RESTRICT lanes, const size_t num_lanes)
Definition: vqsort-inl.h:127
HWY_INLINE bool AllFalse(hwy::SizeTag< 1 >, const Mask256< T > mask)
Definition: x86_256-inl.h:4543
HWY_INLINE Mask128< T, N > Or(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:889
HWY_INLINE Mask128< T, N > AndNot(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:852
HWY_INLINE size_t CountTrue(hwy::SizeTag< 1 >, const Mask128< T > mask)
Definition: arm_neon-inl.h:5609
HWY_INLINE Vec128< T, N > IfThenElse(hwy::SizeTag< 1 >, Mask128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: x86_128-inl.h:670
constexpr bool IsFull(Simd< T, N, kPow2 >)
Definition: ops/shared-inl.h:115
HWY_INLINE Mask512< T > Not(hwy::SizeTag< 1 >, const Mask512< T > m)
Definition: x86_512-inl.h:1613
HWY_INLINE Mask128< T, N > Xor(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:926
d
Definition: rvv-inl.h:1998
HWY_API Mask128< T, N > FirstN(const Simd< T, N, 0 > d, size_t num)
Definition: arm_neon-inl.h:2456
typename detail::CappedTagChecker< T, kLimit >::type CappedTag
Definition: ops/shared-inl.h:184
void Print(const D d, const char *caption, VecArg< V > v, size_t lane_u=0, size_t max_lanes=7)
Definition: print-inl.h:39
HWY_API void BlendedStore(Vec128< T, N > v, Mask128< T, N > m, Simd< T, N, 0 > d, T *HWY_RESTRICT p)
Definition: arm_neon-inl.h:2941
HWY_API constexpr size_t Lanes(Simd< T, N, kPow2 >)
Definition: arm_sve-inl.h:243
HWY_API Vec128< T, N > Load(Simd< T, N, 0 > d, const T *HWY_RESTRICT p)
Definition: arm_neon-inl.h:2753
void Sort(D d, Traits st, T *HWY_RESTRICT keys, size_t num, T *HWY_RESTRICT buf)
Definition: vqsort-inl.h:1438
HWY_API void StoreU(const Vec128< uint8_t > v, Full128< uint8_t >, uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2772
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2591
HWY_API Vec128< T, N > Zero(Simd< T, N, 0 > d)
Definition: arm_neon-inl.h:1020
HWY_API TFromV< V > GetLane(const V v)
Definition: arm_neon-inl.h:1076
typename detail::FixedTagChecker< T, kNumLanes >::type FixedTag
Definition: ops/shared-inl.h:200
HWY_API void SafeCopyN(const size_t num, D d, const T *HWY_RESTRICT from, T *HWY_RESTRICT to)
Definition: generic_ops-inl.h:111
HWY_API size_t CompressStore(Vec128< T, N > v, const Mask128< T, N > mask, Simd< T, N, 0 > d, T *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:6248
N
Definition: rvv-inl.h:1998
HWY_API size_t CompressBlendedStore(Vec128< T, N > v, Mask128< T, N > m, Simd< T, N, 0 > d, T *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:6257
HWY_API size_t FindKnownFirstTrue(const Simd< T, N, 0 > d, const Mask128< T, N > mask)
Definition: arm_neon-inl.h:5683
HWY_API void Store(Vec128< T, N > v, Simd< T, N, 0 > d, T *HWY_RESTRICT aligned)
Definition: arm_neon-inl.h:2934
const vfloat64m1_t v
Definition: rvv-inl.h:1998
decltype(Zero(D())) Vec
Definition: generic_ops-inl.h:40
Definition: aligned_allocator.h:27
HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T *p)
Definition: cache_control.h:77
HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x)
Definition: base.h:831
constexpr size_t CeilLog2(TI x)
Definition: base.h:899
#define HWY_MAX_BYTES
Definition: set_macros-inl.h:84
#define HWY_LANES(T)
Definition: set_macros-inl.h:85
#define HWY_ALIGN
Definition: set_macros-inl.h:83
#define HWY_NAMESPACE
Definition: set_macros-inl.h:82
Definition: arm_neon-inl.h:5729
Definition: contrib/sort/shared-inl.h:28
static constexpr size_t kMaxCols
Definition: contrib/sort/shared-inl.h:34
static constexpr HWY_INLINE size_t LanesPerChunk(size_t sizeof_t)
Definition: contrib/sort/shared-inl.h:69
static constexpr size_t kMaxRows
Definition: contrib/sort/shared-inl.h:43
static constexpr HWY_INLINE size_t BaseCaseNum(size_t N)
Definition: contrib/sort/shared-inl.h:45
static constexpr size_t kMaxRowsLog2
Definition: contrib/sort/shared-inl.h:42
static constexpr size_t kPartitionUnroll
Definition: contrib/sort/shared-inl.h:54
#define VQSORT_PRINT
Definition: vqsort-inl.h:21