Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
import org.apache.flink.table.expressions.AggregateExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.functions.AsyncLookupFunction;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.LookupFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
Expand Down Expand Up @@ -799,19 +798,33 @@ public boolean applyAggregates(
return false;
}

FunctionDefinition functionDefinition = aggregateExpressions.get(0).getFunctionDefinition();
if (!(functionDefinition
.getClass()
.getCanonicalName()
.equals(
"org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction")
|| functionDefinition
.getClass()
.getCanonicalName()
.equals(
"org.apache.flink.table.planner.functions.aggfunctions.Count1AggFunction"))) {
AggregateExpression aggExpr = aggregateExpressions.get(0);
String functionName = aggExpr.getFunctionDefinition().getClass().getCanonicalName();

// Verify that the aggregate function is COUNT(*) or COUNT(1)
// CountAggFunction: COUNT(*) or COUNT(column)
// Count1AggFunction: COUNT(1) with constant argument
boolean isCountAgg =
"org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction"
.equals(functionName);
boolean isCount1Agg =
"org.apache.flink.table.planner.functions.aggfunctions.Count1AggFunction"
.equals(functionName);
if (!isCountAgg && !isCount1Agg) {
return false;
}

// For COUNT(column), reject if column is nullable (cannot handle NULL filtering)
if (isCountAgg) {
List<org.apache.flink.table.expressions.Expression> args = aggExpr.getChildren();
if (!args.isEmpty() && args.get(0) instanceof ResolvedExpression) {
ResolvedExpression arg = (ResolvedExpression) args.get(0);
if (arg.getOutputDataType().getLogicalType().isNullable()) {
return false;
}
}
}
Comment on lines +817 to +826

selectRowCount = true;
this.producedDataType = dataType.getLogicalType();
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,33 @@ void testCountPushDownForPkTable(boolean partitionTable) throws Exception {
List<String> expected = Collections.singletonList("+I[5]");
assertThat(collected).isEqualTo(expected);

// test COUNT(column) pushdown on non-nullable column
query = String.format("SELECT COUNT(id) FROM %s", tableName);
assertThat(tEnv.explainSql(query))
.contains(
"aggregates=[grouping=[], aggFunctions=[Count1AggFunction()]]]], fields=[count1$0]");
iterRows = tEnv.executeSql(query).collect();
collected = collectRowsWithTimeout(iterRows, 1);
assertThat(collected).isEqualTo(expected);

// test COUNT(column) on nullable column - should NOT push down
// For PK table, this will fail because it doesn't support full scan in batch mode
assertThatThrownBy(
() ->
tEnv.explainSql(
String.format("SELECT COUNT(address) FROM %s", tableName)))
.hasMessageContaining(
"Currently, Fluss only support queries on table with datalake enabled or point queries on primary key when it's in batch execution mode.");

assertThatThrownBy(
() ->
tEnv.explainSql(
String.format(
"SELECT COUNT(DISTINCT address) FROM %s",
tableName)))
.hasMessageContaining(
"Currently, Fluss only support queries on table with datalake enabled or point queries on primary key when it's in batch execution mode.");

// test not push down grouping count.
assertThatThrownBy(
() ->
Expand Down Expand Up @@ -452,6 +479,32 @@ void testCountPushDownForLogTable(boolean partitionTable) throws Exception {
List<String> expected = Collections.singletonList(String.format("+I[%s]", expectedRows));
assertThat(collected).isEqualTo(expected);

// test COUNT(column) pushdown
query = String.format("SELECT COUNT(id) FROM %s", tableName);
assertThat(tEnv.explainSql(query))
.contains(
"aggregates=[grouping=[], aggFunctions=[Count1AggFunction()]]]], fields=[count1$0]");
iterRows = tEnv.executeSql(query).collect();
collected = collectRowsWithTimeout(iterRows, 1);
assertThat(collected).isEqualTo(expected);

// test COUNT(column) with NULL values - should NOT push down for nullable columns
// This will fail because log table doesn't support full scan in batch mode
assertThatThrownBy(
() ->
tEnv.explainSql(
String.format("SELECT COUNT(address) FROM %s", tableName)))
.hasMessageContaining(
"Currently, Fluss only support queries on table with datalake enabled or point queries on primary key when it's in batch execution mode.");
assertThatThrownBy(
() ->
tEnv.explainSql(
String.format(
"SELECT COUNT(DISTINCT address) FROM %s",
tableName)))
.hasMessageContaining(
"Currently, Fluss only support queries on table with datalake enabled or point queries on primary key when it's in batch execution mode.");
Comment on lines +491 to +506

// test not push down grouping count.
assertThatThrownBy(
() ->
Expand Down Expand Up @@ -536,11 +589,11 @@ private String prepareLogTable() throws Exception {

TablePath tablePath = TablePath.of(DEFAULT_DB, tableName);

// prepare table data
// prepare table data with NULL values in address column
try (Table table = conn.getTable(tablePath)) {
AppendWriter appendWriter = table.newAppend().createWriter();
for (int i = 1; i <= 5; i++) {
Object[] values = new Object[] {i, "address" + i, "name" + i};
Object[] values = new Object[] {i, i % 2 == 0 ? null : "address" + i, "name" + i};
appendWriter.append(row(values));
// make sure every bucket has records
appendWriter.flush();
Comment on lines +592 to 599
Expand Down Expand Up @@ -571,12 +624,15 @@ protected String preparePartitionedLogTable() throws Exception {
waitUntilPartitions(FLUSS_CLUSTER_EXTENSION.getZooKeeperClient(), tablePath);
Collection<String> partitions = partitionNameById.values();

// prepare table data
// prepare table data with NULL values in address column
try (Table table = conn.getTable(tablePath)) {
AppendWriter appendWriter = table.newAppend().createWriter();
for (int i = 1; i <= 5; i++) {
for (String partition : partitions) {
Object[] values = new Object[] {i, "address" + i, "name" + i, partition};
Object[] values =
new Object[] {
i, i % 2 == 0 ? null : "address" + i, "name" + i, partition
};
appendWriter.append(row(values));
// make sure every bucket has records
appendWriter.flush();
Expand Down
Loading