#pragma once #include #include #include #include #include #include #include #include #include #include #include #include

/** Simple, not numerically stable implementations of the
  * variance/covariance/correlation functions.
  *
  * They are about two times faster than the stable variants,
  * but numerical errors may occur during summation.
  *
  * This implementation is selected as the default,
  * following the "you don't pay for what you don't need" principle.
  *
  * For a more sophisticated implementation, look at AggregateFunctionStatistics.h
  */

namespace DB
{

struct Settings;

/// Selects which statistic an AggregateFunctionVarianceSimple instance reports
/// from insertResultInto(). The enumerator name is also what getName() returns.
enum class StatisticsFunctionKind : uint8_t
{
    varPop, varSamp,
    stddevPop, stddevSamp,
    skewPop, skewSamp,
    kurtPop, kurtSamp,
    covarPop, covarSamp,
    corr
};

/// Traits for statistics taking one argument (num_args == 1).
/// Type1 and Type2 are both T, ResultType is either Float32 or Float64,
/// and Data is the VarMoments aggregation state.
template struct StatFuncOneArg
{
    using Type1 = T;
    using Type2 = T;
    using ResultType = std::conditional_t, Float32, Float64>;
    using Data = VarMoments;

    static constexpr UInt32 num_args = 1;
};

/// Traits for statistics taking two arguments (num_args == 2).
/// Type1/Type2 are the two argument types, ResultType is either Float32 or Float64,
/// and Data is the aggregation state built from the Moments template.
template typename Moments> struct StatFuncTwoArg
{
    using Type1 = T1;
    using Type2 = T2;
    using ResultType = std::conditional_t && std::is_same_v, Float32, Float64>;
    using Data = Moments;

    static constexpr UInt32 num_args = 2;
};

/// Aggregate function that accumulates rows into the state described by StatFunc
/// and, on finalization, emits the statistic selected by `kind`
/// (see the switch in insertResultInto()).
template class AggregateFunctionVarianceSimple final
    : public IAggregateFunctionDataHelper>
{
public:
    using T1 = typename StatFunc::Type1;
    using T2 = typename StatFunc::Type2;
    using ColVecT1 = ColumnVectorOrDecimal;
    using ColVecT2 = ColumnVectorOrDecimal;
    using ResultType = typename StatFunc::ResultType;
    using ColVecResult = ColumnVector;

    /// Stores the statistic kind. The argument list must be non-empty (checked with chassert).
    /// If the first argument type is Decimal, its scale is remembered in src_scale
    /// so that add() can convert decimal values; otherwise src_scale stays 0.
    explicit AggregateFunctionVarianceSimple(const DataTypes & argument_types_, StatisticsFunctionKind kind_)
        : IAggregateFunctionDataHelper>(argument_types_, {}, std::make_shared>())
        , src_scale(0), kind(kind_)
    {
        chassert(!argument_types_.empty());
        if (isDecimal(argument_types_.front()))
            src_scale = getDecimalScale(*argument_types_.front());
    }

    /// The function name is the enumerator name of `kind` (via magic_enum).
    String getName() const override
    {
        return String(magic_enum::enum_name(kind));
    }

    bool allocatesMemoryInArena() const override { return false; }

    /// Feeds the values at row `row_num` into the state at `place`.
    /// Two-argument functions read columns[0] and columns[1];
    /// one-argument functions read columns[0] only. The Arena argument is unused.
    void add(AggregateDataPtr __restrict place,
const IColumn ** columns, size_t row_num, Arena *) const override
    {
        if constexpr (StatFunc::num_args == 2)
            this->data(place).add(
                static_cast(static_cast(*columns[0]).getData()[row_num]),
                static_cast(static_cast(*columns[1]).getData()[row_num]));
        else
        {
            if constexpr (is_decimal)
            {
                /// Decimal input: convert through convertFromDecimal using the scale
                /// captured in the constructor before accumulating.
                this->data(place).add(
                    convertFromDecimal, DataTypeFloat64>(
                        static_cast(*columns[0]).getData()[row_num], src_scale));
            }
            else
                this->data(place).add(
                    static_cast(static_cast(*columns[0]).getData()[row_num]));
        }
    }

    /// Combines the state at `rhs` into the state at `place` by delegating to the state's merge().
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        this->data(place).merge(this->data(rhs));
    }

    /// Writes the state to `buf` via the state's write(); the version argument is ignored.
    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override
    {
        this->data(place).write(buf);
    }

    /// Reads the state from `buf` via the state's read(); the version and Arena arguments are ignored.
    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena *) const override
    {
        this->data(place).read(buf);
    }

    /// Appends exactly one value, chosen by `kind`, to the result column `to`.
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        const auto & data = this->data(place);
        auto & dst = static_cast(to).getData();

        switch (kind)
        {
            case StatisticsFunctionKind::varPop:
            {
                dst.push_back(data.getPopulation());
                break;
            }
            case StatisticsFunctionKind::varSamp:
            {
                dst.push_back(data.getSample());
                break;
            }
            /// Standard deviations are the square roots of the corresponding variances.
            case StatisticsFunctionKind::stddevPop:
            {
                dst.push_back(sqrt(data.getPopulation()));
                break;
            }
            case StatisticsFunctionKind::stddevSamp:
            {
                dst.push_back(sqrt(data.getSample()));
                break;
            }
            /// Skewness: getMoment3() divided by variance^1.5.
            /// NaN is emitted unless the variance is > 0.
            case StatisticsFunctionKind::skewPop:
            {
                ResultType var_value = data.getPopulation();

                if (var_value > 0)
                    dst.push_back(static_cast(data.getMoment3() / pow(var_value, 1.5)));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());

                break;
            }
            case StatisticsFunctionKind::skewSamp:
            {
                ResultType var_value = data.getSample();

                if (var_value > 0)
                    dst.push_back(static_cast(data.getMoment3() / pow(var_value, 1.5)));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());

                break;
            }
            /// Kurtosis: getMoment4() divided by variance^2.
            /// NaN is emitted unless the variance is > 0.
            case StatisticsFunctionKind::kurtPop:
            {
ResultType var_value = data.getPopulation();

                if (var_value > 0)
                    dst.push_back(static_cast(data.getMoment4() / pow(var_value, 2)));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());

                break;
            }
            case StatisticsFunctionKind::kurtSamp:
            {
                ResultType var_value = data.getSample();

                if (var_value > 0)
                    dst.push_back(static_cast(data.getMoment4() / pow(var_value, 2)));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());

                break;
            }
            /// Covariances use the same getPopulation()/getSample() accessors as the variances.
            case StatisticsFunctionKind::covarPop:
            {
                dst.push_back(data.getPopulation());
                break;
            }
            case StatisticsFunctionKind::covarSamp:
            {
                dst.push_back(data.getSample());
                break;
            }
            case StatisticsFunctionKind::corr:
            {
                dst.push_back(data.get());
                break;
            }
        }
    }

private:
    /// Decimal scale of the first argument; remains 0 when that argument is not Decimal.
    UInt32 src_scale;
    /// Which statistic insertResultInto() emits.
    StatisticsFunctionKind kind;
};

struct Settings;

namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}

template