diff --git a/inc/zoo/meta/popcount.h b/inc/zoo/meta/popcount.h index 3b00056b..03555908 100644 --- a/inc/zoo/meta/popcount.h +++ b/inc/zoo/meta/popcount.h @@ -2,6 +2,8 @@ #define ZOO_HEADER_META_POPCOUNT_H #include "zoo/meta/BitmaskMaker.h" +#include +#include namespace zoo { namespace meta { @@ -90,6 +92,30 @@ struct PopcountIntrinsic { }; #endif +template +constexpr auto NumBits() { + return sizeof(T) * 8; +} +static_assert(NumBits() == 16); +static_assert(NumBits() == 64); + +template +constexpr +std::enable_if_t && NumBits() <= 64, T> +basic_popcount(T x) { + if constexpr (NumBits() <= 32) { + return __builtin_popcountl(x); + } else { + return __builtin_popcountll(x); + } +} + +static_assert(basic_popcount(0b111) == 3); +static_assert(basic_popcount(0xFF) == 8); +static_assert(basic_popcount(0xFF'FF'FF'FF) == 32); +static_assert(basic_popcount(0xFF'FF'FF'FF'FF'FF'FF'FF) == 64); +static_assert(basic_popcount(0xFF'FF'FF'FF'FF'FF'FF'FF - 2 - 4 - 8) == 61); + }} #endif diff --git a/inc/zoo/swar/SWAR.h b/inc/zoo/swar/SWAR.h index 60ba9540..4344575c 100644 --- a/inc/zoo/swar/SWAR.h +++ b/inc/zoo/swar/SWAR.h @@ -299,7 +299,7 @@ constexpr auto broadcast(SWAR v) { /// BooleanSWAR treats the MSB of each SWAR lane as the boolean associated with that lane. template -struct BooleanSWAR: SWAR { +struct BooleanSWAR : SWAR { using Base = SWAR; template> @@ -308,8 +308,13 @@ struct BooleanSWAR: SWAR { { this->m_v <<= (NBits - 1); } // Booleanness is stored in the MSBs - static constexpr auto MaskMSB = - broadcast(Base(T(1) << (NBits -1))); + static constexpr auto MaskMSB = []{ + if constexpr (SWAR::Lanes == 1) { + return Base(T{~0}); // all on, no lanes + } + return broadcast(Base(T(1) << (NBits - 1))); + }(); + static constexpr auto AllTrue = MaskMSB; static constexpr auto MaskLSB = broadcast(Base(T(1))); @@ -392,6 +397,11 @@ BooleanSWAR( const bool (&values)[BooleanSWAR::Lanes] ) -> BooleanSWAR; +template +BooleanSWAR( + SWAR arg +) -> BooleanSWAR; + template constexpr BooleanSWAR convertToBooleanSWAR(SWAR arg) noexcept { @@ -586,6 +596,10 @@ constexpr SWAR logarithmFloor(SWAR v) noexcept { return SWAR{popcounts - ones}; } + + + + static_assert( logarithmFloor(SWAR<8>{0x8040201008040201ull}).value() == 0x0706050403020100ull diff --git a/inc/zoo/swar/associative_iteration.h b/inc/zoo/swar/associative_iteration.h index a515dc38..f78838f5 100644 --- a/inc/zoo/swar/associative_iteration.h +++ b/inc/zoo/swar/associative_iteration.h @@ -1,12 +1,15 @@ #ifndef ZOO_SWAR_ASSOCIATIVE_ITERATION_H #define ZOO_SWAR_ASSOCIATIVE_ITERATION_H +#include "Operations.h" +#include "zoo/meta/popcount.h" #include "zoo/swar/SWAR.h" //#define ZOO_DEVELOPMENT_DEBUGGING #ifdef ZOO_DEVELOPMENT_DEBUGGING #include + inline std::ostream &binary(std::ostream &out, uint64_t input, int count) { while(count--) { out << (1 & input); @@ -363,7 +366,7 @@ constexpr auto negate(SWAR input) { return fullAddition(~input, Ones).result; } -/// \brief Performs a generalized iterated application of an associative operator to a base +/// \brief Performs a generalized iterated application of an associative operator to a bases /// /// In algebra, the repeated application of an operator to a "base" has different names depending on the /// operator, for example "a + a + a + ... + a" n-times would be called "repeated addition", @@ -386,17 +389,16 @@ constexpr auto negate(SWAR input) { /// \param log2Count is to potentially reduce the number of iterations if the caller a-priori knows /// there are fewer iterations than what the type of exponent would allow template< - typename Base, typename IterationCount, typename Operator, - // the critical use of associativity is that it allows halving the - // iteration count - typename CountHalver + typename Base, typename IterationCount, + typename Operator, typename CountHalver > constexpr auto associativeOperatorIterated_regressive( - Base base, Base neutral, IterationCount count, IterationCount forSquaring, - Operator op, unsigned log2Count, CountHalver ch + Base base, Base neutral, IterationCount count, + IterationCount forSquaring, Operator op, + unsigned log2Count, CountHalver ch ) { - auto result = neutral; - if(!log2Count) { return result; } + auto result = neutral; // sum = 0 + if(!log2Count) { return result; } // NBits per lane for(;;) { result = op(result, base, count); if(!--log2Count) { break; } @@ -406,6 +408,55 @@ constexpr auto associativeOperatorIterated_regressive( return result; } +namespace count_halving { + +constexpr auto ConsumeMsb = [](auto counts) { + return counts << 1; +}; + +constexpr auto ConsumeLsb = [](auto counts) { + return counts << 1; +}; + +template +constexpr auto ConsumeMsbLaneWise = [](auto counts) { + auto msbCleared = counts & ~S{S::MostSignificantBit}; + return S{msbCleared.value() << 1}; +}; + + +} + +namespace associative_iteration { + +template< + auto Operator, + auto CountHalver, + typename Base, + typename IterationCount +> +constexpr auto regressive( + Base base, + Base neutral, + IterationCount count, + IterationCount forSquaring, + unsigned log2Count +) { + auto result = neutral; // sum = 0 + if(!log2Count) { return result; } // NBits per lane + for(;;) { + result = Operator(result, base, count); + if(!--log2Count) { break; } + result = Operator(result, result, forSquaring); + count = CountHalver(count); + } + return result; +} + +} + +namespace ai = associative_iteration; + template constexpr auto multiplication_OverflowUnsafe_SpecificBitCount( SWAR multiplicand, SWAR multiplier @@ -417,15 +468,10 @@ constexpr auto multiplication_OverflowUnsafe_SpecificBitCount( return left + (addendums & right); }; - auto halver = [](auto counts) { - auto msbCleared = counts & ~S{S::MostSignificantBit}; - return S{msbCleared.value() << 1}; - }; - auto shifted = S{multiplier.value() << (NB - ActualBits)}; return associativeOperatorIterated_regressive( multiplicand, S{0}, shifted, S{S::MostSignificantBit}, operation, - ActualBits, halver + ActualBits, count_halving::ConsumeMsbLaneWise ); } @@ -475,7 +521,7 @@ constexpr auto exponentiation_OverflowUnsafe_SpecificBitCount( exponent = S{static_cast(exponent.value() << (NB - ActualBits))}; return associativeOperatorIterated_regressive( x, - S{meta::BitmaskMaker().value}, // neutral is lane wise.. + S{S::LeastSignificantBit}, exponent, S{S::MostSignificantBit}, operation, @@ -509,10 +555,11 @@ constexpr SWAR doublingMask() { } template + constexpr auto doublePrecision(SWAR input) { using S = SWAR; static_assert( - 0 == S::NSlots % 2, + 0 == S::Lanes % 2, "Precision can only be doubled for SWARs of even element count" ); using RV = SWAR; @@ -523,6 +570,25 @@ constexpr auto doublePrecision(SWAR input) { }; } +template +constexpr +std::enable_if_t= 2 && (S::Lanes % 2) == 0, typename S::type> +horizontalSum_lanes(S s) { + using STwiceWider = SWAR; + constexpr auto + Ones = STwiceWider::LeastSignificantBit, + ShiftBackAmount = STwiceWider::NBits * (STwiceWider::Lanes - 1); + + constexpr auto sum = [](auto a) { + return (a.value() * Ones) >> ShiftBackAmount; + }; + + auto [even, odd] = doublePrecision(s); + return sum(even) + sum(odd); +} + +static_assert(horizontalSum_lanes(SWAR { Literals<32, u64>, {2, 1} }) == 3, "Test failed"); + template constexpr auto halvePrecision(SWAR even, SWAR odd) { using S = SWAR; @@ -535,6 +601,149 @@ constexpr auto halvePrecision(SWAR even, SWAR odd) { return evenHalf | oddHalf; } +template +auto multiply_and_double_p(S a, S b) { + auto product = a * b; + return doublePrecision(product); } +template +constexpr auto horizontalSum_bits(S input) { + constexpr auto + MSBs = S::MostSignificantBit, + Neutral = typename S::type {0}, + ForSquaring = Neutral, + Base = Neutral, + Log2Count = S::NBits; + + auto count = input.value(); + + constexpr auto Operation = [](auto result, auto base, auto count) { + auto msb_masked = count & MSBs; + auto popcount = meta::basic_popcount(msb_masked); + result += popcount + base; + return result; + }; + + return ai::regressive ( + Base, + Neutral, + count, + ForSquaring, + Log2Count + ); +} + +template +constexpr auto is_odd(T input) { + return input & T{1}; +} + +template +constexpr auto horizontalSum(S input) { + if constexpr (is_odd(S::NBits)) { + return horizontalSum_bits(input); + } else { + return horizontalSum_lanes(input); + } +} + +namespace experimental { + +template +constexpr auto horizontalSum_prog(S x) { + constexpr auto Ones = S::LeastSignificantBit, + NBits = S::NBits, + InitialSquare = typename S::type { 1 }; + + auto sum = 0; + auto square = InitialSquare; + auto value = x.value(); + + for (int i = 0; i < NBits; i++) { + auto msb_masked = value & Ones; + auto popcount = zoo::meta::basic_popcount(msb_masked); + auto value_at_square = popcount * square; + sum += value_at_square; + square <<= 1; + value >>= 1; + } + + return sum; +} + +template +constexpr auto horizontalSum_reg(S x) { + constexpr auto MSBs = S::MostSignificantBit, + NBits = S::NBits, + Neutral = typename S::type {0}; + auto result = Neutral; + auto count = x.value(); + + auto operation = [](auto result, auto count) { + auto msb_masked = count & MSBs; + auto popcount = zoo::meta::basic_popcount(msb_masked); + result <<= 1; + result += popcount; + return result; + }; + + for (auto log2Count = NBits;;) { + result = operation(result, count); + if (!--log2Count) { break; } + count = count_halving::ConsumeMsb(count); + } + + return result; +} + +} // namespace experimental + + +#define ZOO_PP_UNPARENTHESIZE(...) __VA_ARGS__ +#define Y(fn, TYPE, values, expected) \ + static_assert(fn( \ + SWAR { \ + Literals, \ + {ZOO_PP_UNPARENTHESIZE values} \ + }) == \ + expected \ + ); + +#define HORIZONTAL_SUM_TESTS(fn) \ + Y(fn, (32, u64), (2, 1), 3) \ + Y(fn, (8, u32), (255, 255, 255, 255), 1020) \ + Y(fn, (8, u32), (255, 254, 255, 255), 1019) \ + Y(fn, (8, u32), (255, 255, 255, 255), 1020) \ + Y(fn, (8, u64), (255, 255, 255, 255, 255, 255, 255, 255), 2040) \ + Y(fn, (4, u64), (15, 15, 15, 15, 15, 15, 15, 15, \ + 15, 15, 15, 15, 15, 15, 15, 15), (15 * 16)) \ + Y(fn, (31, u64), (1, 2), 3) \ + Y(fn, (15, u32), (1, 1), 2) \ + Y(fn, (11, u32), (1, 1), 2) \ + Y(fn, (5, u32), (1, 2, 3, 4, 5, 6), 21) \ + Y(fn, (5, u32), (6, 5, 4, 3, 2, 1), 21) \ + + +#define HORIZONTAL_SUM_TESTS_ALL \ + HORIZONTAL_SUM_TESTS(horizontalSum) \ + HORIZONTAL_SUM_TESTS(horizontalSum_bits) \ + // HORIZONTAL_SUM_TESTS(horizontalSum_lanes) /* doesn't work in all by itself */ \ + HORIZONTAL_SUM_TESTS(experimental::horizontalSum_prog) \ + HORIZONTAL_SUM_TESTS(experimental::horizontalSum_reg) + +HORIZONTAL_SUM_TESTS_ALL + +#undef X +#undef Y +#undef HORIZONTAL_SUM_TESTS +#undef HORIZONTAL_SUM_TESTS_ALL + + +static_assert(((0x01'01 * 0x05'01) & 0xFF'00) == 0x06'00, "Test failed"); + +} + + #endif + diff --git a/pokerbotic/inc/ep/core/SWAR.h b/pokerbotic/inc/ep/core/SWAR.h index 5154dc1d..598cbd0f 100644 --- a/pokerbotic/inc/ep/core/SWAR.h +++ b/pokerbotic/inc/ep/core/SWAR.h @@ -104,6 +104,21 @@ static_assert(0x210 == popcount<0>(0x320), ""); static_assert(0x4321 == popcount<1>(0xF754), ""); static_assert(0x50004 == popcount<3>(0x3E001122), ""); +static_assert(4 == popcount<1>(0b1111)); +static_assert(3 == popcount<1>(0b1011)); + +static_assert(3 == popcount<1>(0b1011)); +static_assert(8 == popcount<2>(0xFF)); +static_assert(16 == popcount<3>(0xFF'FF)); +static_assert(24 == popcount<4>(0xFF'FF'FF)); +static_assert(32 == popcount<4>(0xFF'FF'FF'FF)); +static_assert(40 == popcount<5>(0xFF'FF'FF'FF'FF)); +static_assert(48 == popcount<5>(0xFF'FF'FF'FF'FF'FF)); +static_assert(55 == popcount<5>(0xFF'FF'FF'FF'FF'FF'FF - 8)); + +// todo eduardo why is this broken? +// static_assert(64 == popcount<6>(0xFF'FF'FF'FF'FF'FF'FF'FF)); + template constexpr typename std::make_unsigned::type msb(T v) { return 8*sizeof(T) - 1 - __builtin_clzll(v); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b4b4050e..ffb0691b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -59,7 +59,7 @@ if(MSVC) ${CMAKE_BINARY_DIR}/temporary SOURCES ${CMAKE_SOURCE_DIR}/../compiler_bugs/msvc/sfinae.cpp - CMAKE_FLAGS "-DCMAKE_CXX_STANDARD=17" + CMAKE_FLAGS "-DCMAKE_CXX_STANDARD=20" COMPILE_DEFINITIONS -DTRIGGER_MSVC_SFINAE_BUG OUTPUT_VARIABLE RESULT @@ -84,7 +84,7 @@ if(MSVC) endif() else() # Non-MSVC specific configuration (original content) - set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_FLAGS_UBSAN "-fsanitize=undefined -fno-omit-frame-pointer -fno-optimize-sibling-calls -O1 -g") set(CMAKE_CXX_FLAGS_ASAN "-fsanitize=address -fno-omit-frame-pointer") diff --git a/test/swar/BasicOperations.cpp b/test/swar/BasicOperations.cpp index 602384ae..c4b80c7f 100644 --- a/test/swar/BasicOperations.cpp +++ b/test/swar/BasicOperations.cpp @@ -52,6 +52,11 @@ static_assert(\ expected\ ); +static_assert(SWAR{Literals<16, u32>, {1, 2}}.value() == 0x0001'0002); +static_assert(SWAR<5, u32>::Lanes == 6); +static_assert(SWAR<8, u32>::Lanes == 4); +static_assert(SWAR<9, u32>::Lanes == 3); + /* Preserved to illustrate a technique, remove in a few revisions static_assert(SWAR{Literals<32, u64>, {2, 1}}.value() == 0x00000002'00000001); static_assert(SWAR{Literals<32, u64>, {1, 2}}.value() == 0x00000001'00000002); @@ -120,12 +125,25 @@ static_assert(BS{Literals<4, u16>, {T, F, F, F}}.value() == 0b1000'0000'0000'000 static_assert(SWAR{Literals<8, u16>, {2, 1}}.value() == 0x0201); static_assert(SWAR{Literals<8, u16>, {1, 2}}.value() == 0x0102); */ + +static_assert(SWAR{Literals<5, u32>, {1, 1, 1, 1, 1, 1}}.value() == 0b00001'00001'00001'00001'00001'00001); + #define LITERALS_TESTS \ X(\ (32, u64),\ (2, 1),\ 0x00000002'00000001\ );\ +X(\ + (5, u32),\ + (1, 1, 1, 1, 1, 1),\ + 0b00001'00001'00001'00001'00001'00001\ +);\ +X(\ + (8, u32),\ + (255, 255, 255, 255),\ + 0xFF'FF'FF'FF\ +);\ X(\ (32, u64),\ (1, 2),\ @@ -182,6 +200,7 @@ X(\ 0x12\ ) + LITERALS_TESTS