diff --git a/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java b/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java
index e868a84cc1..037790f0be 100644
--- a/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java
+++ b/fluss-client/src/main/java/org/apache/fluss/client/initializer/BucketOffsetsRetrieverImpl.java
@@ -37,10 +37,17 @@ public class BucketOffsetsRetrieverImpl implements OffsetsInitializer.BucketOffsetsRetriever {
 
     private final Admin flussAdmin;
     private final TablePath tablePath;
+    private final Boolean fetchEarliestOffset;
 
     public BucketOffsetsRetrieverImpl(Admin flussAdmin, TablePath tablePath) {
+        this(flussAdmin, tablePath, false);
+    }
+
+    public BucketOffsetsRetrieverImpl(
+            Admin flussAdmin, TablePath tablePath, Boolean fetchEarliestOffset) {
         this.flussAdmin = flussAdmin;
         this.tablePath = tablePath;
+        this.fetchEarliestOffset = fetchEarliestOffset;
     }
 
     @Override
@@ -52,11 +59,15 @@ public Map<Integer, Long> latestOffsets(
     @Override
     public Map<Integer, Long> earliestOffsets(
             @Nullable String partitionName, Collection<Integer> buckets) {
-        Map<Integer, Long> bucketWithOffset = new HashMap<>(buckets.size());
-        for (Integer bucket : buckets) {
-            bucketWithOffset.put(bucket, EARLIEST_OFFSET);
+        if (!fetchEarliestOffset) {
+            Map<Integer, Long> bucketWithOffset = new HashMap<>(buckets.size());
+            for (Integer bucket : buckets) {
+                bucketWithOffset.put(bucket, EARLIEST_OFFSET);
+            }
+            return bucketWithOffset;
+        } else {
+            return listOffsets(partitionName, buckets, new OffsetSpec.EarliestSpec());
         }
-        return bucketWithOffset;
     }
 
     @Override
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala
index 28fb633b52..00d6400f64 100644
--- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkFlussConf.scala
@@ -50,4 +50,14 @@ object SparkFlussConf {
       .durationType()
       .defaultValue(Duration.ofMillis(10000L))
       .withDescription("The timeout for log scanner to poll records.")
+
+  val SCAN_MAX_RECORDS_PER_PARTITION: ConfigOption[java.lang.Long] =
+    ConfigBuilder
+      .key("scan.max.records.per.partition")
+      .longType()
+      .noDefaultValue()
+      .withDescription(
+        "The maximum number of records per Spark input partition when reading a log table. " +
+          "When set, each Fluss bucket whose offset range exceeds this value will be split " +
+          "into multiple partitions. Disabled by default (one partition per bucket).")
 }
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala
index 0fceef7721..7e57efa383 100644
--- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussBatch.scala
@@ -25,6 +25,7 @@ import org.apache.fluss.client.table.scanner.log.LogScanner
 import org.apache.fluss.config.Configuration
 import org.apache.fluss.metadata.{PartitionInfo, TableBucket, TableInfo, TablePath}
 import org.apache.fluss.predicate.Predicate
+import org.apache.fluss.spark.SparkFlussConf
 
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory}
 import org.apache.spark.sql.types.StructType
@@ -91,26 +92,64 @@ class FlussAppendBatch(
   }
 
   override def planInputPartitions(): Array[InputPartition] = {
-    val bucketOffsetsRetrieverImpl = new BucketOffsetsRetrieverImpl(admin, tablePath)
+    val maxRecordsPerPartition: Option[Long] = {
+      val opt = flussConfig.getOptional(SparkFlussConf.SCAN_MAX_RECORDS_PER_PARTITION)
+      if (opt.isPresent) Some(opt.get().longValue()) else None
+    }
+
+    val bucketOffsetsRetrieverImpl = maxRecordsPerPartition match {
+      case Some(_) => new BucketOffsetsRetrieverImpl(admin, tablePath, true)
+      case None => new BucketOffsetsRetrieverImpl(admin, tablePath)
+    }
     val buckets = (0 until tableInfo.getNumBuckets).toSeq
 
+    def splitOffsetRange(
+        tableBucket: TableBucket,
+        startOffset: Long,
+        stopOffset: Long,
+        maxRecords: Long): Seq[InputPartition] = {
+      if (
+        startOffset < 0 || stopOffset <= startOffset || stopOffset <= (startOffset + maxRecords)
+      ) {
+        return Seq(
+          FlussAppendInputPartition(tableBucket, startOffset, stopOffset)
+            .asInstanceOf[InputPartition])
+      }
+      val rangeSize = stopOffset - startOffset
+      val numSplits = ((rangeSize + maxRecords - 1) / maxRecords).toInt
+      val step = (rangeSize + numSplits - 1) / numSplits
+
+      Iterator
+        .from(0)
+        .take(numSplits)
+        .map(i => startOffset + i * step)
+        .map {
+          from =>
+            FlussAppendInputPartition(tableBucket, from, math.min(from + step, stopOffset))
+              .asInstanceOf[InputPartition]
+        }
+        .toSeq
+    }
+
     def createPartitions(
         partitionId: Option[Long],
         startBucketOffsets: Map[Integer, Long],
         stoppingBucketOffsets: Map[Integer, Long]): Array[InputPartition] = {
-      buckets.map {
+      buckets.flatMap {
         bucketId =>
-          val (startBucketOffset, stoppingBucketOffset) =
+          val (startOffset, stopOffset) =
             (startBucketOffsets(bucketId), stoppingBucketOffsets(bucketId))
-          partitionId match {
-            case Some(partitionId) =>
-              val tableBucket = new TableBucket(tableInfo.getTableId, partitionId, bucketId)
-              FlussAppendInputPartition(tableBucket, startBucketOffset, stoppingBucketOffset)
-                .asInstanceOf[InputPartition]
+          val tableBucket = partitionId match {
+            case Some(pid) => new TableBucket(tableInfo.getTableId, pid, bucketId)
+            case None => new TableBucket(tableInfo.getTableId, bucketId)
+          }
+          maxRecordsPerPartition match {
+            case Some(maxRecs) =>
+              splitOffsetRange(tableBucket, startOffset, stopOffset, maxRecs)
             case None =>
-              val tableBucket = new TableBucket(tableInfo.getTableId, bucketId)
-              FlussAppendInputPartition(tableBucket, startBucketOffset, stoppingBucketOffset)
-                .asInstanceOf[InputPartition]
+              Seq(
+                FlussAppendInputPartition(tableBucket, startOffset, stopOffset)
+                  .asInstanceOf[InputPartition])
           }
       }.toArray
     }
diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala
index d616478aa4..ec733441eb 100644
--- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala
+++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala
@@ -524,4 +524,13 @@ class SparkLogTableReadTest extends FlussSparkTestBase {
       assert(numRowsRead == 5L, s"Expected 5 rows read, got $numRowsRead")
     }
   }
+
+  test("Spark Read: split partition by config") {
+    withSampleTable {
+      withSQLConf("spark.sql.fluss.scan.max.records.per.partition" -> "2") {
+        val query = sql(s"SELECT amount FROM $DEFAULT_DATABASE.t ORDER BY orderId")
+        checkAnswer(query, Row(601) :: Row(602) :: Row(603) :: Row(604) :: Row(605) :: Nil)
+      }
+    }
+  }
 }
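
Reviewer note: the ceil-division in `splitOffsetRange` is the core of this change, so below is a minimal standalone sketch of that arithmetic against a hypothetical offset range. The object and helper names are invented for illustration; the formulas mirror the patch.

```scala
// Standalone sketch of the split arithmetic from splitOffsetRange; the object
// and method names here are illustrative only, not part of the patch.
object SplitOffsetRangeDemo extends App {

  // numSplits = ceil(rangeSize / maxRecords) caps every split at maxRecords;
  // step = ceil(rangeSize / numSplits) then keeps the splits near-equal.
  def splitRanges(startOffset: Long, stopOffset: Long, maxRecords: Long): Seq[(Long, Long)] = {
    if (startOffset < 0 || stopOffset <= startOffset || stopOffset <= startOffset + maxRecords) {
      return Seq((startOffset, stopOffset)) // small or degenerate range: one partition
    }
    val rangeSize = stopOffset - startOffset
    val numSplits = ((rangeSize + maxRecords - 1) / maxRecords).toInt
    val step = (rangeSize + numSplits - 1) / numSplits
    (0 until numSplits)
      .map(i => startOffset + i * step)
      .map(from => (from, math.min(from + step, stopOffset)))
  }

  // Offsets [0, 10) with a limit of 4: ceil(10/4) = 3 splits, step = ceil(10/3) = 4.
  println(splitRanges(0L, 10L, 4L)) // Vector((0,4), (4,8), (8,10))
}
```

Recomputing `step` from `numSplits` instead of stepping by `maxRecords` directly balances the split sizes: a range of 101 records with a limit of 50 yields splits of 34/34/33 rather than 50/50/1.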
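
The option reaches the reader through Spark's session conf under the `spark.sql.fluss.` prefix, as the new test exercises via `withSQLConf`. A hypothetical end-to-end use (catalog, database, and table names invented):

```scala
// Hypothetical usage; the key prefix "spark.sql.fluss." is the one the new
// test sets via withSQLConf, but the catalog/database/table names are invented.
spark.conf.set("spark.sql.fluss.scan.max.records.per.partition", "100000")

// Any bucket whose start/stop offset range exceeds 100000 records is now
// planned as several Spark input partitions instead of one.
val df = spark.sql("SELECT * FROM fluss_catalog.mydb.my_log_table")
println(df.rdd.getNumPartitions) // can exceed the table's bucket count
```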