From d1dd5ce3261f8ad1815657ed4d3dbeebddcb041b Mon Sep 17 00:00:00 2001
From: Junbo Wang
Date: Fri, 8 May 2026 15:52:13 +0800
Subject: [PATCH] [flink] Fix COUNT(column) aggregate pushdown to reject
 nullable columns

---
 .../fluss/flink/source/FlinkTableSource.java  | 37 +++++++----
 .../source/FlinkTableSourceBatchITCase.java   | 64 +++++++++++++++++--
 2 files changed, 85 insertions(+), 16 deletions(-)

diff --git a/fluss-flink/fluss-flink-common/src/main/java/org/apache/fluss/flink/source/FlinkTableSource.java b/fluss-flink/fluss-flink-common/src/main/java/org/apache/fluss/flink/source/FlinkTableSource.java
index 0da1ba7668..edbf667705 100644
--- a/fluss-flink/fluss-flink-common/src/main/java/org/apache/fluss/flink/source/FlinkTableSource.java
+++ b/fluss-flink/fluss-flink-common/src/main/java/org/apache/fluss/flink/source/FlinkTableSource.java
@@ -78,7 +78,6 @@
 import org.apache.flink.table.expressions.AggregateExpression;
 import org.apache.flink.table.expressions.ResolvedExpression;
 import org.apache.flink.table.functions.AsyncLookupFunction;
-import org.apache.flink.table.functions.FunctionDefinition;
 import org.apache.flink.table.functions.LookupFunction;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.logical.LogicalType;
@@ -799,19 +798,33 @@ public boolean applyAggregates(
             return false;
         }
 
-        FunctionDefinition functionDefinition = aggregateExpressions.get(0).getFunctionDefinition();
-        if (!(functionDefinition
-                .getClass()
-                .getCanonicalName()
-                .equals(
-                        "org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction")
-                || functionDefinition
-                        .getClass()
-                        .getCanonicalName()
-                        .equals(
-                                "org.apache.flink.table.planner.functions.aggfunctions.Count1AggFunction"))) {
+        AggregateExpression aggExpr = aggregateExpressions.get(0);
+        String functionName = aggExpr.getFunctionDefinition().getClass().getCanonicalName();
+
+        // Verify that the aggregate function is COUNT(*), COUNT(1), or COUNT(column)
+        // CountAggFunction: COUNT(*) or COUNT(column)
+        // Count1AggFunction: COUNT(1) with constant argument
+        boolean isCountAgg =
+                "org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction"
+                        .equals(functionName);
+        boolean isCount1Agg =
+                "org.apache.flink.table.planner.functions.aggfunctions.Count1AggFunction"
+                        .equals(functionName);
+        if (!isCountAgg && !isCount1Agg) {
             return false;
         }
+
+        // For COUNT(column), reject if column is nullable (cannot handle NULL filtering)
+        if (isCountAgg) {
+            List args = aggExpr.getChildren();
+            if (!args.isEmpty() && args.get(0) instanceof ResolvedExpression) {
+                ResolvedExpression arg = (ResolvedExpression) args.get(0);
+                if (arg.getOutputDataType().getLogicalType().isNullable()) {
+                    return false;
+                }
+            }
+        }
+
         selectRowCount = true;
         this.producedDataType = dataType.getLogicalType();
         return true;
diff --git a/fluss-flink/fluss-flink-common/src/test/java/org/apache/fluss/flink/source/FlinkTableSourceBatchITCase.java b/fluss-flink/fluss-flink-common/src/test/java/org/apache/fluss/flink/source/FlinkTableSourceBatchITCase.java
index 48ee39af11..e7ac69b0c4 100644
--- a/fluss-flink/fluss-flink-common/src/test/java/org/apache/fluss/flink/source/FlinkTableSourceBatchITCase.java
+++ b/fluss-flink/fluss-flink-common/src/test/java/org/apache/fluss/flink/source/FlinkTableSourceBatchITCase.java
@@ -404,6 +404,33 @@ void testCountPushDownForPkTable(boolean partitionTable) throws Exception {
         List expected = Collections.singletonList("+I[5]");
         assertThat(collected).isEqualTo(expected);
 
+        // test COUNT(column) pushdown on non-nullable column
+        query = String.format("SELECT COUNT(id) FROM %s", tableName);
+        assertThat(tEnv.explainSql(query))
+                .contains(
+                        "aggregates=[grouping=[], aggFunctions=[Count1AggFunction()]]]], fields=[count1$0]");
+        iterRows = tEnv.executeSql(query).collect();
+        collected = collectRowsWithTimeout(iterRows, 1);
+        assertThat(collected).isEqualTo(expected);
+
+        // test COUNT(column) on nullable column - should NOT push down
+        // For PK table, this will fail because it doesn't support full scan in batch mode
+        assertThatThrownBy(
+                        () ->
+                                tEnv.explainSql(
+                                        String.format("SELECT COUNT(address) FROM %s", tableName)))
+                .hasMessageContaining(
+                        "Currently, Fluss only support queries on table with datalake enabled or point queries on primary key when it's in batch execution mode.");
+
+        assertThatThrownBy(
+                        () ->
+                                tEnv.explainSql(
+                                        String.format(
+                                                "SELECT COUNT(DISTINCT address) FROM %s",
+                                                tableName)))
+                .hasMessageContaining(
+                        "Currently, Fluss only support queries on table with datalake enabled or point queries on primary key when it's in batch execution mode.");
+
         // test not push down grouping count.
         assertThatThrownBy(
                 () ->
@@ -452,6 +479,32 @@ void testCountPushDownForLogTable(boolean partitionTable) throws Exception {
         List expected = Collections.singletonList(String.format("+I[%s]", expectedRows));
         assertThat(collected).isEqualTo(expected);
 
+        // test COUNT(column) pushdown
+        query = String.format("SELECT COUNT(id) FROM %s", tableName);
+        assertThat(tEnv.explainSql(query))
+                .contains(
+                        "aggregates=[grouping=[], aggFunctions=[Count1AggFunction()]]]], fields=[count1$0]");
+        iterRows = tEnv.executeSql(query).collect();
+        collected = collectRowsWithTimeout(iterRows, 1);
+        assertThat(collected).isEqualTo(expected);
+
+        // test COUNT(column) with NULL values - should NOT push down for nullable columns
+        // This will fail because log table doesn't support full scan in batch mode
+        assertThatThrownBy(
+                        () ->
+                                tEnv.explainSql(
+                                        String.format("SELECT COUNT(address) FROM %s", tableName)))
+                .hasMessageContaining(
+                        "Currently, Fluss only support queries on table with datalake enabled or point queries on primary key when it's in batch execution mode.");
+        assertThatThrownBy(
+                        () ->
+                                tEnv.explainSql(
+                                        String.format(
+                                                "SELECT COUNT(DISTINCT address) FROM %s",
+                                                tableName)))
+                .hasMessageContaining(
+                        "Currently, Fluss only support queries on table with datalake enabled or point queries on primary key when it's in batch execution mode.");
+
         // test not push down grouping count.
         assertThatThrownBy(
                 () ->
@@ -536,11 +589,11 @@ private String prepareLogTable() throws Exception {
         TablePath tablePath = TablePath.of(DEFAULT_DB, tableName);
 
-        // prepare table data
+        // prepare table data with NULL values in address column
         try (Table table = conn.getTable(tablePath)) {
             AppendWriter appendWriter = table.newAppend().createWriter();
             for (int i = 1; i <= 5; i++) {
-                Object[] values = new Object[] {i, "address" + i, "name" + i};
+                Object[] values = new Object[] {i, i % 2 == 0 ? null : "address" + i, "name" + i};
                 appendWriter.append(row(values));
                 // make sure every bucket has records
                 appendWriter.flush();
             }
@@ -571,12 +624,15 @@ protected String preparePartitionedLogTable() throws Exception {
         waitUntilPartitions(FLUSS_CLUSTER_EXTENSION.getZooKeeperClient(), tablePath);
         Collection partitions = partitionNameById.values();
 
-        // prepare table data
+        // prepare table data with NULL values in address column
         try (Table table = conn.getTable(tablePath)) {
             AppendWriter appendWriter = table.newAppend().createWriter();
             for (int i = 1; i <= 5; i++) {
                 for (String partition : partitions) {
-                    Object[] values = new Object[] {i, "address" + i, "name" + i, partition};
+                    Object[] values =
+                            new Object[] {
+                                i, i % 2 == 0 ? null : "address" + i, "name" + i, partition
+                            };
                     appendWriter.append(row(values));
                     // make sure every bucket has records
                     appendWriter.flush();