#pragma once #include #include #include #include #include /** Aggregate function that calculates statistics on top of cross-tab: * - histogram of every argument and every pair of elements. * These statistics include: * - Cramer's V; * - Theil's U; * - contingency coefficient; * It can be interpreted as interdependency coefficient between arguments; * or non-parametric correlation coefficient. */ namespace DB { struct CrossTabData { /// Total count. UInt64 count = 0; /// Count of every value of the first and second argument (values are pre-hashed). /// Note: non-cryptographic 64bit hash is used, it means that the calculation is approximate. HashMapWithStackMemory count_a; HashMapWithStackMemory count_b; /// Count of every pair of values. We pack two hashes into UInt128. HashMapWithStackMemory count_ab; void add(UInt64 hash1, UInt64 hash2) { ++count; ++count_a[hash1]; ++count_b[hash2]; UInt128 hash_pair{hash1, hash2}; ++count_ab[hash_pair]; } void merge(const CrossTabData & other) { count += other.count; for (const auto & [key, value] : other.count_a) count_a[key] += value; for (const auto & [key, value] : other.count_b) count_b[key] += value; for (const auto & [key, value] : other.count_ab) count_ab[key] += value; } void serialize(WriteBuffer & buf) const { writeBinary(count, buf); count_a.write(buf); count_b.write(buf); count_ab.write(buf); } void deserialize(ReadBuffer & buf) { readBinary(count, buf); count_a.read(buf); count_b.read(buf); count_ab.read(buf); } /** See https://en.wikipedia.org/wiki/Cram%C3%A9r%27s_V * * φ² is χ² divided by the sample size (count). * χ² is the sum of squares of the normalized differences between the "expected" and "observed" statistics. * ("Expected" in the case when one of the hypotheses is true). * Something resembling the L2 distance. * * Note: statisticians use the name χ² for every statistic that has χ² distribution in many various contexts. * * Let's suppose that there is no association between the values a and b. * Then the frequency (e.g. probability) of (a, b) pair is equal to the multiplied frequencies of a and b: * count_ab / count = (count_a / count) * (count_b / count) * count_ab = count_a * count_b / count * * Let's calculate the difference between the values that are supposed to be equal if there is no association between a and b: * count_ab - count_a * count_b / count * * Let's sum the squares of the differences across all (a, b) pairs. * Then divide by the second term for normalization: (count_a * count_b / count) * * This will be the χ² statistics. * This statistics is used as a base for many other statistics. */ Float64 getPhiSquared() const { Float64 chi_squared = 0; for (const auto & [key, value_ab] : count_ab) { Float64 value_a = count_a.at(key.items[UInt128::_impl::little(0)]); Float64 value_b = count_b.at(key.items[UInt128::_impl::little(1)]); Float64 expected_value_ab = (value_a * value_b) / count; Float64 chi_squared_elem = value_ab - expected_value_ab; chi_squared_elem = chi_squared_elem * chi_squared_elem / expected_value_ab; chi_squared += chi_squared_elem; } return chi_squared / count; } }; template class AggregateFunctionCrossTab : public IAggregateFunctionDataHelper> { public: explicit AggregateFunctionCrossTab(const DataTypes & arguments) : IAggregateFunctionDataHelper>({arguments}, {}, createResultType()) { } String getName() const override { return Data::getName(); } bool allocatesMemoryInArena() const override { return false; } static DataTypePtr createResultType() { return std::make_shared>(); } void add( AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { UInt64 hash1 = UniqVariadicHash::apply(1, &columns[0], row_num); UInt64 hash2 = UniqVariadicHash::apply(1, &columns[1], row_num); this->data(place).add(hash1, hash2); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); } void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional) const override { this->data(place).serialize(buf); } void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional, Arena *) const override { this->data(place).deserialize(buf); } void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override { Float64 result = this->data(place).getResult(); auto & column = static_cast &>(to); column.getData().push_back(result); } }; }