diff --git a/src/antlr/Parser.g b/src/antlr/Parser.g index 8794d00d02a9..0f5a17193c34 100644 --- a/src/antlr/Parser.g +++ b/src/antlr/Parser.g @@ -807,6 +807,12 @@ rowDataReference returns [RowDataReference.Raw rawRef] : t=sident ('.' s=referenceSelection)? { tuple = t; selectable = s; } ; +indexedRowDataReference returns [RowDataReference.Raw rawRef] + @init { Selectable.RawIdentifier tuple = null; Selectable.Raw selectable = null; } + @after { $rawRef = newRowDataReference(tuple, selectable); } + : t=sident '.' s=referenceSelection { tuple = t; selectable = s; } + ; + referenceSelection returns [Selectable.Raw s] : g=referenceSelectionWithoutField m=selectorModifier[g] {$s = m;} ; @@ -2105,9 +2111,16 @@ normalColumnOperation[UpdateStatement.OperationCollector operations, ColumnIdent addRecognitionError("Only expressions of the form X = X " + ($i.text.charAt(0) == '-' ? '-' : '+') + " are supported."); addRawUpdate(operations, key, new Operation.Addition(Constants.Literal.integer($i.text))); } - | {isParsingTxn}? r=rowDataReference + | {isParsingTxn}? r=indexedRowDataReference (sig=('+'|'-') t=term)? { - addRawReferenceOperation(operations, key, new ReferenceOperation.Raw(new Operation.SetValue(r), key, new ReferenceValue.Substitution.Raw(r))); + if (t == null) + { + addRawReferenceOperation(operations, key, new ReferenceOperation.Raw(new Operation.SetValue(r), key, new ReferenceValue.Substitution.Raw(r))); + } + else + { + addRawReferenceOperation(operations, key, new ReferenceOperation.Raw($sig.text.equals("+") ? new Operation.Addition(t) : new Operation.Substraction(t), key, new ReferenceValue.Substitution.Raw(r))); + } } ; diff --git a/src/java/org/apache/cassandra/cql3/Operation.java b/src/java/org/apache/cassandra/cql3/Operation.java index cca7bb0d2b9c..58dfbca2c193 100644 --- a/src/java/org/apache/cassandra/cql3/Operation.java +++ b/src/java/org/apache/cassandra/cql3/Operation.java @@ -17,6 +17,7 @@ */ package org.apache.cassandra.cql3; +import java.nio.ByteBuffer; import java.util.List; import org.apache.cassandra.cql3.functions.Function; @@ -120,6 +121,11 @@ public void collectMarkerSpecification(VariableSpecifications boundNames, Object */ public abstract void execute(DecoratedKey partitionKey, RowUpdateBuilder builder) throws InvalidRequestException; + public void execute(DecoratedKey partitionKey, RowUpdateBuilder builder, ByteBuffer term) throws InvalidRequestException + { + throw new UnsupportedOperationException(); + } + /** * A parsed raw UPDATE operation. * diff --git a/src/java/org/apache/cassandra/cql3/statements/UpdateStatement.java b/src/java/org/apache/cassandra/cql3/statements/UpdateStatement.java index 11e60477c763..eb3b0d485b38 100644 --- a/src/java/org/apache/cassandra/cql3/statements/UpdateStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/UpdateStatement.java @@ -214,7 +214,7 @@ else if (value instanceof ReferenceValue.Raw) { ReferenceValue.Raw raw = (ReferenceValue.Raw) value; ReferenceValue referenceValue = raw.prepare(def, bindVariables); - ReferenceOperation operation = new ReferenceOperation(def, metadata, TxnReferenceOperation.Kind.setterFor(def), null, null, referenceValue); + ReferenceOperation operation = new ReferenceOperation(def, metadata, TxnReferenceOperation.Kind.setterFor(def), null, null, null, referenceValue); operations.add(def, operation); } else diff --git a/src/java/org/apache/cassandra/cql3/terms/Constants.java b/src/java/org/apache/cassandra/cql3/terms/Constants.java index 655608291651..fdd4cc369a91 100644 --- a/src/java/org/apache/cassandra/cql3/terms/Constants.java +++ b/src/java/org/apache/cassandra/cql3/terms/Constants.java @@ -21,6 +21,8 @@ import java.math.BigInteger; import java.nio.ByteBuffer; +import accord.utils.Invariants; + import org.apache.cassandra.cql3.AssignmentTestable; import org.apache.cassandra.cql3.CQL3Type; import org.apache.cassandra.cql3.ColumnSpecification; @@ -550,6 +552,30 @@ else if (column.type instanceof StringType) builder.addCell(column, newValue); } } + + public void execute(DecoratedKey partitionKey, RowUpdateBuilder builder, ByteBuffer constant) throws InvalidRequestException + { + Invariants.require(constant != null); + if (column.type instanceof NumberType) + { + @SuppressWarnings("unchecked") NumberType type = (NumberType) column.type; + ByteBuffer increment = type.sanitize(t.bindAndGet(builder)); + if (increment == null) + return; + ByteBuffer newValue = type.add(type.compose(constant), type.compose(increment)); + builder.addCell(column, newValue); + } + else if (column.type instanceof StringType) + { + ByteBuffer left = t.bindAndGet(builder); + if (left == null) + return; + ByteBuffer newValue = ByteBuffer.allocate(left.remaining() + constant.remaining()); + FastByteOperations.copy(left, left.position(), newValue, newValue.position(), left.remaining()); + FastByteOperations.copy(constant, constant.position(), newValue, newValue.position() + left.remaining(), constant.remaining()); + builder.addCell(column, newValue); + } + } } public static class Substracter extends Operation @@ -594,6 +620,20 @@ else if (column.type instanceof NumberType) builder.addCell(column, newValue); } } + + public void execute(DecoratedKey partitionKey, RowUpdateBuilder builder, ByteBuffer constant) throws InvalidRequestException + { + Invariants.require(constant != null); + if (column.type instanceof NumberType) + { + @SuppressWarnings("unchecked") NumberType type = (NumberType) column.type; + ByteBuffer increment = type.sanitize(t.bindAndGet(builder)); + if (increment == null) + return; + ByteBuffer newValue = type.substract(type.compose(increment), type.compose(constant)); + builder.addCell(column, newValue); + } + } } // This happens to also handle collection because it doesn't felt worth diff --git a/src/java/org/apache/cassandra/cql3/transactions/ReferenceOperation.java b/src/java/org/apache/cassandra/cql3/transactions/ReferenceOperation.java index 2b4b6f999f0f..554d443dedcc 100644 --- a/src/java/org/apache/cassandra/cql3/transactions/ReferenceOperation.java +++ b/src/java/org/apache/cassandra/cql3/transactions/ReferenceOperation.java @@ -48,21 +48,23 @@ public class ReferenceOperation private final TableMetadata table; private final TxnReferenceOperation.Kind kind; private final FieldIdentifier field; + private final Term constant; private final Term key; private final ReferenceValue value; - public ReferenceOperation(ColumnMetadata receiver, TableMetadata table, TxnReferenceOperation.Kind kind, Term key, FieldIdentifier field, ReferenceValue value) + public ReferenceOperation(ColumnMetadata receiver, TableMetadata table, TxnReferenceOperation.Kind kind, Term key, FieldIdentifier field, Term constant, ReferenceValue value) { this.receiver = receiver; this.table = table; this.kind = kind; this.key = key; this.field = field; + this.constant = constant; this.value = value; } /** - * Creates a {@link ReferenceOperation} from the given {@link Operation} for the purpose of defering execution + * Creates a {@link ReferenceOperation} from the given {@link Operation} for the purpose of defering execution * within a transaction. When the language sees an Operation using a reference one is created already, but for cases * that needs to defer execution (such as when {@link Operation#requiresRead()} is true), this method can be used. */ @@ -75,7 +77,7 @@ public static ReferenceOperation create(Operation operation, TableMetadata table ReferenceValue value = new ReferenceValue.Constant(operation.term()); Term key = extractKeyOrIndex(operation); FieldIdentifier field = extractField(operation); - return new ReferenceOperation(receiver, table, kind, key, field, value); + return new ReferenceOperation(receiver, table, kind, key, field, null, value); } public TxnReferenceOperation.Kind getKind() @@ -105,6 +107,7 @@ public TxnReferenceOperation bindAndGet(QueryOptions options) receiver, table, key != null ? key.bindAndGet(options) : null, field != null ? field.bytes : null, + constant != null ? constant.bindAndGet(options) : null, value.bindAndGet(options)); } @@ -157,7 +160,11 @@ public ReferenceOperation prepare(TableMetadata metadata, VariableSpecifications } } - return new ReferenceOperation(receiver, metadata, kind, key, field, value.prepare(valueReceiver, bindVariables)); + ReferenceValue referenceValue = value.prepare(valueReceiver, bindVariables); + + // When operation.term().equals(referenceValue.getTerm()), we are in the case where we have v += row1.c and + // when this is not true, we are in the case where we have v = row1.c + 3 + return new ReferenceOperation(receiver, metadata, kind, key, field, operation.term().equals(referenceValue.getTerm()) ? null : operation.term(), referenceValue); } } diff --git a/src/java/org/apache/cassandra/cql3/transactions/ReferenceValue.java b/src/java/org/apache/cassandra/cql3/transactions/ReferenceValue.java index 5d4fb64e97ec..c6ded2af0bf6 100644 --- a/src/java/org/apache/cassandra/cql3/transactions/ReferenceValue.java +++ b/src/java/org/apache/cassandra/cql3/transactions/ReferenceValue.java @@ -33,6 +33,8 @@ public abstract class ReferenceValue { public abstract TxnReferenceValue bindAndGet(FunctionContext context); + public abstract Term getTerm(); + public static abstract class Raw extends Term.Raw { public abstract ReferenceValue prepare(ColumnMetadata receiver, VariableSpecifications bindVariables); @@ -53,6 +55,12 @@ public TxnReferenceValue bindAndGet(FunctionContext context) return new TxnReferenceValue.Constant(term.bindAndGet(context)); } + @Override + public Term getTerm() + { + return term; + } + public static class Raw extends ReferenceValue.Raw { private final Term.Raw term; @@ -109,6 +117,12 @@ public TxnReferenceValue bindAndGet(FunctionContext context) return new TxnReferenceValue.Substitution(reference.toTxnReference(context).asColumn()); } + @Override + public Term getTerm() + { + return reference; + } + public static class Raw extends ReferenceValue.Raw { private final RowDataReference.Raw reference; diff --git a/src/java/org/apache/cassandra/cql3/transactions/RowDataReference.java b/src/java/org/apache/cassandra/cql3/transactions/RowDataReference.java index 0d376439346e..94e4d0767ae3 100644 --- a/src/java/org/apache/cassandra/cql3/transactions/RowDataReference.java +++ b/src/java/org/apache/cassandra/cql3/transactions/RowDataReference.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.Map; +import java.util.Objects; import com.google.common.base.Preconditions; @@ -101,6 +102,19 @@ public void addFunctionsTo(List functions) throw new UnsupportedOperationException("Functions are not currently supported w/ reference terms."); } + @Override + public boolean equals(Object o) + { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RowDataReference that = (RowDataReference) o; + return Objects.equals(txnDataName, that.txnDataName) + && Objects.equals(column, that.column) + && Objects.equals(table, that.table) + && Objects.equals(elementPath, that.elementPath) + && Objects.equals(fieldPath, that.fieldPath); + } + public ColumnMetadata toResultMetadata() { ColumnIdentifier fullName = getFullyQualifiedName(); diff --git a/src/java/org/apache/cassandra/service/accord/txn/TxnReferenceOperation.java b/src/java/org/apache/cassandra/service/accord/txn/TxnReferenceOperation.java index 7d47f8afd2bc..29816f846bad 100644 --- a/src/java/org/apache/cassandra/service/accord/txn/TxnReferenceOperation.java +++ b/src/java/org/apache/cassandra/service/accord/txn/TxnReferenceOperation.java @@ -170,18 +170,20 @@ public Operation toOperation(ColumnMetadata column, Term keyOrIndex, FieldIdenti public final TableMetadata table; private final @Nullable ByteBuffer keyOrIndex; private final @Nullable ByteBuffer field; + private final @Nullable ByteBuffer constant; private final TxnReferenceValue value; private final @Nullable AbstractType keyOrIndexType; private final AbstractType valueType; public TxnReferenceOperation(Kind kind, ColumnMetadata receiver, TableMetadata table, - @Nullable ByteBuffer keyOrIndex, @Nullable ByteBuffer field, TxnReferenceValue value) + @Nullable ByteBuffer keyOrIndex, @Nullable ByteBuffer field, @Nullable ByteBuffer constant, TxnReferenceValue value) { this.kind = kind; this.receiver = receiver; this.table = table; this.keyOrIndex = keyOrIndex; this.field = field; + this.constant = constant; // We don't expect operators on clustering keys, but unwrap just in case. AbstractType receiverType = receiver.type.unwrap(); @@ -272,7 +274,11 @@ public ColumnMetadata receiver() public void apply(TxnData data, DecoratedKey key, RowUpdateBuilder up) { Operation operation = toOperation(data); - operation.execute(key, up); + // When constant != null, we are performing a computation with a LET variable (i.e. row1.v + 2) + if (constant != null) + operation.execute(key, up, constant); + else + operation.execute(key, up); } @VisibleForTesting @@ -306,10 +312,16 @@ else if (receivingType.isTuple()) static final ParameterisedUnversionedSerializer serializer = new ParameterisedUnversionedSerializer<>() { + private static final int TOP_BIT = 0x40; + @Override public void serialize(TxnReferenceOperation operation, TableMetadatas tables, DataOutputPlus out) throws IOException { - out.writeByte(operation.kind.id); + if (operation.constant != null) + out.writeUnsignedVInt32(operation.kind.id | TOP_BIT); + else + out.writeUnsignedVInt32(operation.kind.id); + tables.serialize(operation.table, out); columnMetadataSerializer.serialize(operation.receiver, operation.table, out); TxnReferenceValue.serializer.serialize(operation.value, tables, out); @@ -321,25 +333,41 @@ public void serialize(TxnReferenceOperation operation, TableMetadatas tables, Da out.writeBoolean(operation.field != null); if (operation.field != null) ByteBufferUtil.writeWithVIntLength(operation.field, out); + + // The boolean for whether operation.constant is null is encoded + // in the TOP_BIT of operation.kind.id, this is to ensure that everything + // serialized by the new nodes can be deserialized by the old nodes modulo + // the new CQL syntax allowing calcuations with LET variables within + // the update statement + if (operation.constant != null) + ByteBufferUtil.writeWithVIntLength(operation.constant, out); } @Override public TxnReferenceOperation deserialize(TableMetadatas tables, DataInputPlus in) throws IOException { - Kind kind = Kind.from(in.readByte()); + int flags = in.readUnsignedVInt32(); + Kind kind; + if ((flags & TOP_BIT) != 0) + kind = Kind.from((byte) (flags ^ TOP_BIT)); + else + kind = Kind.from((byte) (flags)); TableMetadata table = tables.deserialize(in); ColumnMetadata receiver = columnMetadataSerializer.deserialize(table, in); TxnReferenceValue value = TxnReferenceValue.serializer.deserialize(tables, in); ByteBuffer key = in.readBoolean() ? ByteBufferUtil.readWithVIntLength(in) : null; ByteBuffer field = in.readBoolean() ? ByteBufferUtil.readWithVIntLength(in) : null; - return new TxnReferenceOperation(kind, receiver, table, key, field, value); + ByteBuffer constant = null; + if ((flags & TOP_BIT) != 0) + constant = ByteBufferUtil.readWithVIntLength(in); + return new TxnReferenceOperation(kind, receiver, table, key, field, constant, value); } @Override public long serializedSize(TxnReferenceOperation operation, TableMetadatas tables) { - long size = Byte.BYTES; - size += tables.serializedSize(operation.table); + long size = TypeSizes.sizeofUnsignedVInt(operation.kind.id | TOP_BIT); + size += tables.serializedSize(operation.table); size += columnMetadataSerializer.serializedSize(operation.receiver, operation.table); size += TxnReferenceValue.serializer.serializedSize(operation.value, tables); @@ -351,6 +379,9 @@ public long serializedSize(TxnReferenceOperation operation, TableMetadatas table if (operation.field != null) size += ByteBufferUtil.serializedSizeWithVIntLength(operation.field); + if (operation.constant != null) + size += ByteBufferUtil.serializedSizeWithVIntLength(operation.constant); + return size; } }; diff --git a/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java b/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java index f38b84ede287..842e3d058076 100644 --- a/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java +++ b/test/distributed/org/apache/cassandra/distributed/test/accord/AccordCQLTestBase.java @@ -84,6 +84,7 @@ import org.apache.cassandra.distributed.test.sai.SAIUtil; import org.apache.cassandra.distributed.util.QueryResultUtil; import org.apache.cassandra.exceptions.InvalidRequestException; +import org.apache.cassandra.exceptions.SyntaxException; import org.apache.cassandra.exceptions.WriteTimeoutException; import org.apache.cassandra.io.util.DataInputBuffer; import org.apache.cassandra.schema.SchemaConstants; @@ -3604,4 +3605,68 @@ public void userSeesInvalidRejection() throws Exception .hasMessage("Attempted to set an element on a list which is null"); }); } + + @Test + public void testLETVariableReferenceInUpdateFails() throws Exception + { + // Regression test for prior NPE + test("CREATE TABLE " + qualifiedAccordTableName + " (k int, c int, v int, PRIMARY KEY (k, c)) WITH " + transactionalMode.asCqlParam(), cluster -> { + try + { + String txn = "BEGIN TRANSACTION\n" + + " LET r = (SELECT v FROM " + qualifiedAccordTableName + " WHERE k = 1 AND c = 1);\n" + + " UPDATE " + qualifiedAccordTableName + " SET v = r WHERE k=1 AND c=1;\n" + + "COMMIT TRANSACTION"; + + cluster.coordinator(1).executeWithResult(txn, ConsistencyLevel.SERIAL); + fail("Expected exception"); + } + catch (Throwable t) + { + assertEquals(SyntaxException.class.getName(), t.getClass().getName()); + } + }); + } + + @Test + public void testUseLetVariableForEvaluationWithInt() throws Exception + { + test("CREATE TABLE " + qualifiedAccordTableName + " (k int, c int, v int, PRIMARY KEY (k, c)) WITH " + transactionalMode.asCqlParam(), cluster -> { + cluster.coordinator(1).execute("INSERT INTO " + qualifiedAccordTableName + " (k, c, v) VALUES (1, 1, 5)", ConsistencyLevel.ALL); + + String update = "BEGIN TRANSACTION\n" + + " LET row1 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 1 AND c = 1);\n" + + " UPDATE " + qualifiedAccordTableName + " SET v = row1.v + 3 WHERE k = 1 AND c = 1;\n" + + "COMMIT TRANSACTION"; + cluster.coordinator(1).executeWithResult(update, ConsistencyLevel.ALL); + + String read = "BEGIN TRANSACTION\n" + + "SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 1;\n" + + "COMMIT TRANSACTION"; + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult(read, ConsistencyLevel.SERIAL); + assertThat(result).hasSize(1).contains(1, 1, 8); + }); + } + + @Test + public void testUseLetVariableForEvaluationWithString() throws Exception + { + test("CREATE TABLE " + qualifiedAccordTableName + " (k int, c int, v text, PRIMARY KEY (k, c)) WITH " + transactionalMode.asCqlParam(), cluster -> { + cluster.coordinator(1).execute("INSERT INTO " + qualifiedAccordTableName + " (k, c, v) VALUES (?, ?, ?)", ConsistencyLevel.ALL, 1, 1, "a"); + + String update = "BEGIN TRANSACTION\n" + + " LET row1 = (SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 1 AND c = 1);\n" + + " UPDATE " + qualifiedAccordTableName + " SET v = row1.v + ? WHERE k = 1 AND c = 1;\n" + + "COMMIT TRANSACTION"; + cluster.coordinator(1).executeWithResult(update, ConsistencyLevel.ALL, "m"); + + String read = "BEGIN TRANSACTION\n" + + "SELECT * FROM " + qualifiedAccordTableName + " WHERE k = 1;\n" + + "COMMIT TRANSACTION"; + + SimpleQueryResult result = cluster.coordinator(1).executeWithResult(read, ConsistencyLevel.SERIAL); + assertThat(result).hasSize(1).contains(1, 1, "am"); + }); + } } diff --git a/test/unit/org/apache/cassandra/service/accord/txn/TxnReferenceOperationTest.java b/test/unit/org/apache/cassandra/service/accord/txn/TxnReferenceOperationTest.java index 0b271e2b718f..82f915a7c632 100644 --- a/test/unit/org/apache/cassandra/service/accord/txn/TxnReferenceOperationTest.java +++ b/test/unit/org/apache/cassandra/service/accord/txn/TxnReferenceOperationTest.java @@ -123,6 +123,7 @@ private static Gen gen() TableMetadata table; @Nullable ByteBuffer keyOrIndex = null; @Nullable ByteBuffer field = null; + @Nullable ByteBuffer constant = null; TxnReferenceValue value; Group group = rs.pick(Group.values()); switch (group) @@ -195,6 +196,8 @@ else if (type instanceof SetType && rs.nextBoolean()) table = table(type); receiver = table.getColumn(ColumnIdentifier.getInterned("col", true)); value = valueGen(type).next(rs); + if (rs.nextBoolean()) + constant = Generators.toGen(AbstractTypeGenerators.getTypeSupport(type).bytesGen()).next(rs); kind = group == Group.Adder ? TxnReferenceOperation.Kind.ConstantAdder : TxnReferenceOperation.Kind.ConstantSubtracter; } } @@ -276,7 +279,8 @@ else if (type instanceof UserType) default: throw new UnsupportedOperationException(); } - return new TxnReferenceOperation(kind, receiver, table, keyOrIndex, field, value); + + return new TxnReferenceOperation(kind, receiver, table, keyOrIndex, field, constant, value); }; }