diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 38d55a54..139d5f4e 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -1,9 +1,7 @@ import thrift from 'thrift'; -import Int64 from 'node-int64'; import { EventEmitter } from 'events'; import TCLIService from '../thrift/TCLIService'; -import { TProtocolVersion } from '../thrift/TCLIService_types'; import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient'; import IDriver from './contracts/IDriver'; import IClientContext, { ClientConfig } from './contracts/IClientContext'; @@ -14,9 +12,11 @@ import IDBSQLSession from './contracts/IDBSQLSession'; import IAuthentication from './connection/contracts/IAuthentication'; import HttpConnection from './connection/connections/HttpConnection'; import IConnectionOptions from './connection/contracts/IConnectionOptions'; -import Status from './dto/Status'; import HiveDriverError from './errors/HiveDriverError'; -import { buildUserAgentString, definedOrError, serializeQueryTags } from './utils'; +import { buildUserAgentString } from './utils'; +import IBackend from './contracts/IBackend'; +import ThriftBackend from './thrift-backend/ThriftBackend'; +import SeaBackend from './sea/SeaBackend'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; import DatabricksOAuth, { OAuthFlow } from './connection/auth/DatabricksOAuth'; import { @@ -39,19 +39,6 @@ function prependSlash(str: string): string { return str; } -function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) { - if (!catalogName && !schemaName) { - return {}; - } - - return { - initialNamespace: { - catalogName, - schemaName, - }, - }; -} - export type ThriftLibrary = Pick; export default class DBSQLClient extends EventEmitter implements IDBSQLClient, IClientContext { @@ -75,6 +62,8 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I private readonly sessions = new CloseableCollection(); + private backend?: IBackend; + private static getDefaultLogger(): IDBSQLLogger { if (!this.defaultLogger) { this.defaultLogger = new DBSQLLogger(); @@ -248,38 +237,45 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I this.connectionProvider = this.createConnectionProvider(options); - const thriftConnection = await this.connectionProvider.getThriftConnection(); - - thriftConnection.on('error', (error: Error) => { - // Error.stack already contains error type and message, so log stack if available, - // otherwise fall back to just error type + message - this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`); - try { - this.emit('error', error); - } catch (e) { - // EventEmitter will throw unhandled error when emitting 'error' event. - // Since we already logged it few lines above, just suppress this behaviour - } - }); - - thriftConnection.on('reconnecting', (params: { delay: number; attempt: number }) => { - this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`); - this.emit('reconnecting', params); - }); - - thriftConnection.on('close', () => { - this.logger.log(LogLevel.debug, 'Closing connection.'); - this.emit('close'); - }); + this.backend = options.useSEA + ? new SeaBackend() + : new ThriftBackend({ + context: this, + onConnectionEvent: (event, payload) => this.forwardConnectionEvent(event, payload), + }); - thriftConnection.on('timeout', () => { - this.logger.log(LogLevel.debug, 'Connection timed out.'); - this.emit('timeout'); - }); + await this.backend.connect(options); return this; } + private forwardConnectionEvent(event: 'error' | 'reconnecting' | 'close' | 'timeout', payload?: unknown): void { + switch (event) { + case 'error': { + const error = payload as Error; + this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`); + try { + this.emit('error', error); + } catch (e) { + // EventEmitter throws when 'error' has no listeners; we've already logged it. + } + return; + } + case 'reconnecting': + this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(payload)}`); + this.emit('reconnecting', payload); + return; + case 'close': + this.logger.log(LogLevel.debug, 'Closing connection.'); + this.emit('close'); + return; + case 'timeout': + this.logger.log(LogLevel.debug, 'Connection timed out.'); + this.emit('timeout'); + // no default + } + } + /** * Starts new session * @public @@ -290,44 +286,20 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I * const session = await client.openSession(); */ public async openSession(request: OpenSessionRequest = {}): Promise { - // Prepare session configuration - const configuration = request.configuration ? { ...request.configuration } : {}; - - // Add metric view metadata config if enabled - if (this.config.enableMetricViewMetadata) { - configuration['spark.sql.thriftserver.metadata.metricview.enabled'] = 'true'; - } - - // Serialize queryTags dict and set in configuration; takes precedence over configuration.QUERY_TAGS - if (request.queryTags !== undefined) { - const serialized = serializeQueryTags(request.queryTags); - if (serialized) { - configuration.QUERY_TAGS = serialized; - } else { - delete configuration.QUERY_TAGS; - } + if (!this.backend) { + throw new HiveDriverError('DBSQLClient: not connected'); } - - const response = await this.driver.openSession({ - client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8), - ...getInitialNamespaceOptions(request.initialCatalog, request.initialSchema), - configuration, - canUseMultipleCatalogs: true, - }); - - Status.assert(response.status); - const session = new DBSQLSession({ - handle: definedOrError(response.sessionHandle), - context: this, - serverProtocolVersion: response.serverProtocolVersion, - }); + const sessionBackend = await this.backend.openSession(request); + const session = new DBSQLSession({ backend: sessionBackend, context: this }); this.sessions.add(session); return session; } public async close(): Promise { await this.sessions.closeAll(); + await this.backend?.close(); + this.backend = undefined; this.client = undefined; this.connectionProvider = undefined; this.authProvider = undefined; diff --git a/lib/DBSQLOperation.ts b/lib/DBSQLOperation.ts index fe22995d..24f5058d 100644 --- a/lib/DBSQLOperation.ts +++ b/lib/DBSQLOperation.ts @@ -1,4 +1,3 @@ -import { stringify, NIL } from 'uuid'; import { Readable } from 'node:stream'; import IOperation, { FetchOptions, @@ -16,87 +15,54 @@ import { TTableSchema, TSparkDirectResults, TGetResultSetMetadataResp, - TSparkRowSetType, - TCloseOperationResp, - TOperationState, } from '../thrift/TCLIService_types'; import Status from './dto/Status'; import { LogLevel } from './contracts/IDBSQLLogger'; import OperationStateError, { OperationStateErrorCode } from './errors/OperationStateError'; -import IResultsProvider from './result/IResultsProvider'; -import RowSetProvider from './result/RowSetProvider'; -import JsonResultHandler from './result/JsonResultHandler'; -import ArrowResultHandler from './result/ArrowResultHandler'; -import CloudFetchResultHandler from './result/CloudFetchResultHandler'; -import ArrowResultConverter from './result/ArrowResultConverter'; -import ResultSlicer from './result/ResultSlicer'; -import { definedOrError } from './utils'; import { OperationChunksIterator, OperationRowsIterator } from './utils/OperationIterator'; -import HiveDriverError from './errors/HiveDriverError'; import IClientContext from './contracts/IClientContext'; - -interface DBSQLOperationConstructorOptions { - handle: TOperationHandle; - directResults?: TSparkDirectResults; - context: IClientContext; -} - -async function delay(ms?: number): Promise { - return new Promise((resolve) => { - setTimeout(() => { - resolve(); - }, ms); - }); -} +import IOperationBackend from './contracts/IOperationBackend'; +import ThriftOperationBackend from './thrift-backend/ThriftOperationBackend'; + +type DBSQLOperationConstructorOptions = + | { + handle: TOperationHandle; + directResults?: TSparkDirectResults; + context: IClientContext; + } + | { + backend: IOperationBackend; + context: IClientContext; + }; export default class DBSQLOperation implements IOperation { private readonly context: IClientContext; - private readonly operationHandle: TOperationHandle; + private readonly backend: IOperationBackend; public onClose?: () => void; - private readonly _data: RowSetProvider; - - private readonly closeOperation?: TCloseOperationResp; - private closed: boolean = false; private cancelled: boolean = false; - private metadata?: TGetResultSetMetadataResp; - - private metadataPromise?: Promise; - - private state: TOperationState = TOperationState.INITIALIZED_STATE; - - // Once operation is finished or fails - cache status response, because subsequent calls - // to `getOperationStatus()` may fail with irrelevant errors, e.g. HTTP 404 - private operationStatus?: TGetOperationStatusResp; - - private resultHandler?: ResultSlicer; - - constructor({ handle, directResults, context }: DBSQLOperationConstructorOptions) { - this.operationHandle = handle; - this.context = context; - - const useOnlyPrefetchedResults = Boolean(directResults?.closeOperation); - - if (directResults?.operationStatus) { - this.processOperationStatusResponse(directResults.operationStatus); - } - - this.metadata = directResults?.resultSetMetadata; - this._data = new RowSetProvider( - this.context, - this.operationHandle, - [directResults?.resultSet], - useOnlyPrefetchedResults, - ); - this.closeOperation = directResults?.closeOperation; + constructor(options: DBSQLOperationConstructorOptions) { + this.context = options.context; + this.backend = + 'backend' in options + ? options.backend + : new ThriftOperationBackend({ + handle: options.handle, + directResults: options.directResults, + context: options.context, + }); this.context.getLogger().log(LogLevel.debug, `Operation created with id: ${this.id}`); } + public get id() { + return this.backend.id; + } + public iterateChunks(options?: IteratorOptions): IOperationChunksIterator { return new OperationChunksIterator(this, options); } @@ -122,11 +88,6 @@ export default class DBSQLOperation implements IOperation { return Readable.from(iterable, options?.streamOptions); } - public get id() { - const operationId = this.operationHandle?.operationId?.guid; - return operationId ? stringify(operationId) : NIL; - } - /** * Fetches all data * @public @@ -141,8 +102,6 @@ export default class DBSQLOperation implements IOperation { const fetchChunkOptions = { ...options, - // Tell slicer to return raw chunks. We're going to process all of them anyway, - // so no need to additionally buffer and slice chunks returned by server disableBuffering: true, }; @@ -168,47 +127,19 @@ export default class DBSQLOperation implements IOperation { public async fetchChunk(options?: FetchOptions): Promise> { await this.failIfClosed(); - if (!this.operationHandle.hasResultSet) { + if (!this.backend.hasResultSet) { return []; } - await this.waitUntilReady(options); - - const resultHandler = await this.getResultHandler(); + await this.waitUntilReadyThroughBackend(options); await this.failIfClosed(); - // All the library code is Promise-based, however, since Promises are microtasks, - // enqueueing a lot of promises may block macrotasks execution for a while. - // Usually, there are no much microtasks scheduled, however, when fetching query - // results (especially CloudFetch ones) it's quite easy to block event loop for - // long enough to break a lot of things. For example, with CloudFetch, after first - // set of files are downloaded and being processed immediately one by one, event - // loop easily gets blocked for enough time to break connection pool. `http.Agent` - // stops receiving socket events, and marks all sockets invalid on the next attempt - // to use them. See these similar issues that helped to debug this particular case - - // https://github.com/nodejs/node/issues/47130 and https://github.com/node-fetch/node-fetch/issues/1735 - // This simple fix allows to clean up a microtasks queue and allow Node to process - // macrotasks as well, allowing the normal operation of other code. Also, this - // fix is added to `fetchChunk` method because, unlike other methods, `fetchChunk` is - // a potential source of issues described above - await new Promise((resolve) => { - setTimeout(resolve, 0); - }); - const defaultMaxRows = this.context.getConfig().fetchChunkDefaultMaxRows; - - const result = resultHandler.fetchNext({ - limit: options?.maxRows ?? defaultMaxRows, - disableBuffering: options?.disableBuffering, - }); + const limit = options?.maxRows ?? defaultMaxRows; + const result = await this.backend.fetchChunk({ limit, disableBuffering: options?.disableBuffering }); await this.failIfClosed(); - this.context - .getLogger() - .log( - LogLevel.debug, - `Fetched chunk of size: ${options?.maxRows ?? defaultMaxRows} from operation with id: ${this.id}`, - ); + this.context.getLogger().log(LogLevel.debug, `Fetched chunk of size: ${limit} from operation with id: ${this.id}`); return result; } @@ -220,18 +151,7 @@ export default class DBSQLOperation implements IOperation { public async status(progress: boolean = false): Promise { await this.failIfClosed(); this.context.getLogger().log(LogLevel.debug, `Fetching status for operation with id: ${this.id}`); - - if (this.operationStatus) { - return this.operationStatus; - } - - const driver = await this.context.getDriver(); - const response = await driver.getOperationStatus({ - operationHandle: this.operationHandle, - getProgressUpdate: progress, - }); - - return this.processOperationStatusResponse(response); + return this.backend.status(progress); } /** @@ -242,18 +162,8 @@ export default class DBSQLOperation implements IOperation { if (this.closed || this.cancelled) { return Status.success(); } - - this.context.getLogger().log(LogLevel.debug, `Cancelling operation with id: ${this.id}`); - - const driver = await this.context.getDriver(); - const response = await driver.cancelOperation({ - operationHandle: this.operationHandle, - }); - Status.assert(response.status); + const result = await this.backend.cancel(); this.cancelled = true; - const result = new Status(response.status); - - // Cancelled operation becomes unusable, similarly to being closed this.onClose?.(); return result; } @@ -266,63 +176,47 @@ export default class DBSQLOperation implements IOperation { if (this.closed || this.cancelled) { return Status.success(); } - - this.context.getLogger().log(LogLevel.debug, `Closing operation with id: ${this.id}`); - - const driver = await this.context.getDriver(); - const response = - this.closeOperation ?? - (await driver.closeOperation({ - operationHandle: this.operationHandle, - })); - Status.assert(response.status); + const result = await this.backend.close(); this.closed = true; - const result = new Status(response.status); - this.onClose?.(); return result; } public async finished(options?: FinishedOptions): Promise { await this.failIfClosed(); - await this.waitUntilReady(options); + await this.waitUntilReadyThroughBackend(options); } public async hasMoreRows(): Promise { - // If operation is closed or cancelled - we should not try to get data from it if (this.closed || this.cancelled) { return false; } - // Wait for operation to finish before checking for more rows - // This ensures metadata can be fetched successfully - if (this.operationHandle.hasResultSet) { - await this.waitUntilReady(); + if (this.backend.hasResultSet) { + await this.waitUntilReadyThroughBackend(); } - // If we fetched all the data from server - check if there's anything buffered in result handler - const resultHandler = await this.getResultHandler(); - return resultHandler.hasMore(); + return this.backend.hasMore(); } public async getSchema(options?: GetSchemaOptions): Promise { await this.failIfClosed(); - if (!this.operationHandle.hasResultSet) { + if (!this.backend.hasResultSet) { return null; } - await this.waitUntilReady(options); + await this.waitUntilReadyThroughBackend(options); this.context.getLogger().log(LogLevel.debug, `Fetching schema for operation with id: ${this.id}`); - const metadata = await this.fetchMetadata(); + const metadata = await this.backend.getResultMetadata(); return metadata.schema ?? null; } public async getMetadata(): Promise { await this.failIfClosed(); - await this.waitUntilReady(); - return this.fetchMetadata(); + await this.waitUntilReadyThroughBackend(); + return this.backend.getResultMetadata(); } private async failIfClosed(): Promise { @@ -334,151 +228,20 @@ export default class DBSQLOperation implements IOperation { } } - private async waitUntilReady(options?: WaitUntilReadyOptions) { - if (this.state === TOperationState.FINISHED_STATE) { - return; - } - - let isReady = false; - - while (!isReady) { - // eslint-disable-next-line no-await-in-loop - const response = await this.status(Boolean(options?.progress)); - - if (options?.callback) { - // eslint-disable-next-line no-await-in-loop - await Promise.resolve(options.callback(response)); - } - - switch (response.operationState) { - // For these states do nothing and continue waiting - case TOperationState.INITIALIZED_STATE: - case TOperationState.PENDING_STATE: - case TOperationState.RUNNING_STATE: - break; - - // Operation is completed, so exit the loop - case TOperationState.FINISHED_STATE: - isReady = true; - break; - - // Operation was cancelled, so set a flag and exit the loop (throw an error) - case TOperationState.CANCELED_STATE: + private async waitUntilReadyThroughBackend(options?: WaitUntilReadyOptions) { + try { + await this.backend.waitUntilReady(options); + } catch (err) { + // Reflect terminal states back into facade flags so subsequent calls + // short-circuit via failIfClosed(). + if (err instanceof OperationStateError) { + if (err.errorCode === OperationStateErrorCode.Canceled) { this.cancelled = true; - throw new OperationStateError(OperationStateErrorCode.Canceled, response); - - // Operation was closed, so set a flag and exit the loop (throw an error) - case TOperationState.CLOSED_STATE: + } else if (err.errorCode === OperationStateErrorCode.Closed) { this.closed = true; - throw new OperationStateError(OperationStateErrorCode.Closed, response); - - // Error states - throw and exit the loop - case TOperationState.ERROR_STATE: - throw new OperationStateError(OperationStateErrorCode.Error, response); - case TOperationState.TIMEDOUT_STATE: - throw new OperationStateError(OperationStateErrorCode.Timeout, response); - case TOperationState.UKNOWN_STATE: - default: - throw new OperationStateError(OperationStateErrorCode.Unknown, response); - } - - // If not ready yet - make some delay before the next status requests - if (!isReady) { - // eslint-disable-next-line no-await-in-loop - await delay(100); + } } + throw err; } } - - private async fetchMetadata() { - // If metadata is already cached, return it immediately - if (this.metadata) { - return this.metadata; - } - - // If a fetch is already in progress, wait for it to complete - if (this.metadataPromise) { - return this.metadataPromise; - } - - // Start a new fetch and cache the promise to prevent concurrent fetches - this.metadataPromise = (async () => { - const driver = await this.context.getDriver(); - const metadata = await driver.getResultSetMetadata({ - operationHandle: this.operationHandle, - }); - Status.assert(metadata.status); - this.metadata = metadata; - return metadata; - })(); - - try { - return await this.metadataPromise; - } finally { - // Clear the promise once completed (success or failure) - this.metadataPromise = undefined; - } - } - - private async getResultHandler(): Promise> { - const metadata = await this.fetchMetadata(); - const resultFormat = definedOrError(metadata.resultFormat); - - if (!this.resultHandler) { - let resultSource: IResultsProvider> | undefined; - - switch (resultFormat) { - case TSparkRowSetType.COLUMN_BASED_SET: - resultSource = new JsonResultHandler(this.context, this._data, metadata); - break; - case TSparkRowSetType.ARROW_BASED_SET: - resultSource = new ArrowResultConverter( - this.context, - new ArrowResultHandler(this.context, this._data, metadata), - metadata, - ); - break; - case TSparkRowSetType.URL_BASED_SET: - resultSource = new ArrowResultConverter( - this.context, - new CloudFetchResultHandler(this.context, this._data, metadata), - metadata, - ); - break; - // no default - } - - if (resultSource) { - this.resultHandler = new ResultSlicer(this.context, resultSource); - } - } - - if (!this.resultHandler) { - throw new HiveDriverError(`Unsupported result format: ${TSparkRowSetType[resultFormat]}`); - } - - return this.resultHandler; - } - - private processOperationStatusResponse(response: TGetOperationStatusResp) { - Status.assert(response.status); - - this.state = response.operationState ?? this.state; - - if (typeof response.hasResultSet === 'boolean') { - this.operationHandle.hasResultSet = response.hasResultSet; - } - - const isInProgress = [ - TOperationState.INITIALIZED_STATE, - TOperationState.PENDING_STATE, - TOperationState.RUNNING_STATE, - ].includes(this.state); - - if (!isInProgress) { - this.operationStatus = response; - } - - return response; - } } diff --git a/lib/DBSQLSession.ts b/lib/DBSQLSession.ts index 95715e1b..3ca2b73b 100644 --- a/lib/DBSQLSession.ts +++ b/lib/DBSQLSession.ts @@ -2,19 +2,8 @@ import * as fs from 'fs'; import * as path from 'path'; import stream from 'node:stream'; import util from 'node:util'; -import { stringify, NIL } from 'uuid'; -import Int64 from 'node-int64'; import fetch, { HeadersInit } from 'node-fetch'; -import { - TSessionHandle, - TStatus, - TOperationHandle, - TSparkDirectResults, - TSparkArrowTypes, - TSparkParameter, - TProtocolVersion, - TExecuteStatementReq, -} from '../thrift/TCLIService_types'; +import { TSessionHandle, TProtocolVersion } from '../thrift/TCLIService_types'; import IDBSQLSession, { ExecuteStatementOptions, TypeInfoRequest, @@ -31,153 +20,58 @@ import IOperation from './contracts/IOperation'; import DBSQLOperation from './DBSQLOperation'; import Status from './dto/Status'; import InfoValue from './dto/InfoValue'; -import { definedOrError, LZ4, ProtocolVersion, serializeQueryTags } from './utils'; import CloseableCollection from './utils/CloseableCollection'; import { LogLevel } from './contracts/IDBSQLLogger'; import HiveDriverError from './errors/HiveDriverError'; import StagingError from './errors/StagingError'; -import { DBSQLParameter, DBSQLParameterValue } from './DBSQLParameter'; -import ParameterError from './errors/ParameterError'; -import IClientContext, { ClientConfig } from './contracts/IClientContext'; +import IClientContext from './contracts/IClientContext'; +import ISessionBackend from './contracts/ISessionBackend'; +import IOperationBackend from './contracts/IOperationBackend'; +import ThriftSessionBackend from './thrift-backend/ThriftSessionBackend'; // Explicitly promisify a callback-style `pipeline` because `node:stream/promises` is not available in Node 14 const pipeline = util.promisify(stream.pipeline); -interface OperationResponseShape { - status: TStatus; - operationHandle?: TOperationHandle; - directResults?: TSparkDirectResults; -} - -export function numberToInt64(value: number | bigint | Int64): Int64 { - if (value instanceof Int64) { - return value; - } - - if (typeof value === 'bigint') { - const buffer = new ArrayBuffer(BigInt64Array.BYTES_PER_ELEMENT); - const view = new DataView(buffer); - view.setBigInt64(0, value, false); // `false` to use big-endian order - return new Int64(Buffer.from(buffer)); - } - - return new Int64(value); -} - -function getDirectResultsOptions(maxRows: number | bigint | Int64 | null | undefined, config: ClientConfig) { - if (maxRows === null) { - return {}; - } - - return { - getDirectResults: { - maxRows: numberToInt64(maxRows ?? config.directResultsDefaultMaxRows), - }, - }; -} - -function getArrowOptions( - config: ClientConfig, - serverProtocolVersion: TProtocolVersion | undefined | null, -): { - canReadArrowResult: boolean; - useArrowNativeTypes?: TSparkArrowTypes; -} { - const { arrowEnabled = true, useArrowNativeTypes = true } = config; - - if (!arrowEnabled || !ProtocolVersion.supportsArrowMetadata(serverProtocolVersion)) { - return { - canReadArrowResult: false, - }; - } - - return { - canReadArrowResult: true, - useArrowNativeTypes: { - timestampAsArrow: useArrowNativeTypes, - decimalAsArrow: useArrowNativeTypes, - complexTypesAsArrow: useArrowNativeTypes, - // TODO: currently unsupported by `apache-arrow` (see https://github.com/streamlit/streamlit/issues/4489) - intervalTypesAsArrow: false, - }, - }; -} - -function getQueryParameters( - namedParameters?: Record, - ordinalParameters?: Array, -): Array { - const namedParametersProvided = namedParameters !== undefined && Object.keys(namedParameters).length > 0; - const ordinalParametersProvided = ordinalParameters !== undefined && ordinalParameters.length > 0; +// Re-export for back-compat with existing imports. +export { numberToInt64 } from './thrift-backend/ThriftSessionBackend'; - if (namedParametersProvided && ordinalParametersProvided) { - throw new ParameterError('Driver does not support both ordinal and named parameters.'); - } - - if (!namedParametersProvided && !ordinalParametersProvided) { - return []; - } - - const result: Array = []; - - if (namedParameters !== undefined) { - for (const name of Object.keys(namedParameters)) { - const value = namedParameters[name]; - const param = value instanceof DBSQLParameter ? value : new DBSQLParameter({ value }); - result.push(param.toSparkParameter({ name })); - } - } - - if (ordinalParameters !== undefined) { - for (const value of ordinalParameters) { - const param = value instanceof DBSQLParameter ? value : new DBSQLParameter({ value }); - result.push(param.toSparkParameter()); +type DBSQLSessionConstructorOptions = + | { + handle: TSessionHandle; + context: IClientContext; + serverProtocolVersion?: TProtocolVersion; } - } - - return result; -} - -interface DBSQLSessionConstructorOptions { - handle: TSessionHandle; - context: IClientContext; - serverProtocolVersion?: TProtocolVersion; -} + | { + backend: ISessionBackend; + context: IClientContext; + }; export default class DBSQLSession implements IDBSQLSession { private readonly context: IClientContext; - private readonly sessionHandle: TSessionHandle; + private readonly backend: ISessionBackend; private isOpen = true; - private serverProtocolVersion?: TProtocolVersion; - public onClose?: () => void; private operations = new CloseableCollection(); - /** - * Helper method to determine if runAsync should be set for metadata operations - * @private - * @returns true if supported by protocol version, undefined otherwise - */ - private getRunAsyncForMetadataOperations(): boolean | undefined { - return ProtocolVersion.supportsAsyncMetadataOperations(this.serverProtocolVersion) ? true : undefined; - } - - constructor({ handle, context, serverProtocolVersion }: DBSQLSessionConstructorOptions) { - this.sessionHandle = handle; - this.context = context; - // Get the server protocol version from the provided parameter (from TOpenSessionResp) - this.serverProtocolVersion = serverProtocolVersion; + constructor(options: DBSQLSessionConstructorOptions) { + this.context = options.context; + this.backend = + 'backend' in options + ? options.backend + : new ThriftSessionBackend({ + handle: options.handle, + context: options.context, + serverProtocolVersion: options.serverProtocolVersion, + }); this.context.getLogger().log(LogLevel.debug, `Session created with id: ${this.id}`); - this.context.getLogger().log(LogLevel.debug, `Server protocol version: ${this.serverProtocolVersion}`); } public get id() { - const sessionId = this.sessionHandle?.sessionId?.guid; - return sessionId ? stringify(sessionId) : NIL; + return this.backend.id; } /** @@ -190,14 +84,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getInfo(infoType: number): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const operationPromise = driver.getInfo({ - sessionHandle: this.sessionHandle, - infoType, - }); - const response = await this.handleResponse(operationPromise); - Status.assert(response.status); - return new InfoValue(response.infoValue); + const result = await this.backend.getInfo(infoType); + await this.failIfClosed(); + return result; } /** @@ -211,46 +100,11 @@ export default class DBSQLSession implements IDBSQLSession { */ public async executeStatement(statement: string, options: ExecuteStatementOptions = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const request = new TExecuteStatementReq({ - sessionHandle: this.sessionHandle, - statement, - queryTimeout: options.queryTimeout ? numberToInt64(options.queryTimeout) : undefined, - runAsync: true, - ...getDirectResultsOptions(options.maxRows, clientConfig), - ...getArrowOptions(clientConfig, this.serverProtocolVersion), - }); - - if (ProtocolVersion.supportsParameterizedQueries(this.serverProtocolVersion)) { - request.parameters = getQueryParameters(options.namedParameters, options.ordinalParameters); - } - - const serializedQueryTags = serializeQueryTags(options.queryTags); - if (serializedQueryTags !== undefined) { - request.confOverlay = { ...request.confOverlay, query_tags: serializedQueryTags }; - } - - if (ProtocolVersion.supportsCloudFetch(this.serverProtocolVersion)) { - request.canDownloadResult = options.useCloudFetch ?? clientConfig.useCloudFetch; - } - - if (ProtocolVersion.supportsArrowCompression(this.serverProtocolVersion) && request.canDownloadResult !== true) { - request.canDecompressLZ4Result = (options.useLZ4Compression ?? clientConfig.useLZ4Compression) && Boolean(LZ4()); - } + const opBackend = await this.backend.executeStatement(statement, options); + await this.failIfClosed(); + const operation = this.wrapOperation(opBackend); - const operationPromise = driver.executeStatement(request); - const response = await this.handleResponse(operationPromise); - const operation = this.createOperation(response); - - // If `stagingAllowedLocalPath` is provided - assume that operation possibly may be a staging operation. - // To know for sure, fetch metadata and check a `isStagingOperation` flag. If it happens that it wasn't - // a staging operation - not a big deal, we just fetched metadata earlier, but operation is still usable - // and user can get data from it. - // If `stagingAllowedLocalPath` is not provided - don't do anything to the operation. In a case of regular - // operation, everything will work as usual. In a case of staging operation, it will be processed like any - // other query - it will be possible to get data from it as usual, or use other operation methods. + // Staging detection: only run when stagingAllowedLocalPath is provided. if (options.stagingAllowedLocalPath !== undefined) { const metadata = await operation.getMetadata(); if (metadata.isStagingOperation) { @@ -276,7 +130,6 @@ export default class DBSQLSession implements IDBSQLSession { } const row = rows[0] as StagingResponse; - // For REMOVE operation local file is not available, so no need to validate it if (row.localFile !== undefined) { let allowOperation = false; @@ -328,7 +181,6 @@ export default class DBSQLSession implements IDBSQLSession { } const fileStream = fs.createWriteStream(localFile); - // `pipeline` will do all the dirty job for us, including error handling and closing all the streams properly return pipeline(response.body, fileStream); } @@ -337,13 +189,6 @@ export default class DBSQLSession implements IDBSQLSession { const agent = await connectionProvider.getAgent(); const response = await fetch(presignedUrl, { method: 'DELETE', headers, agent }); - // Looks that AWS and Azure have a different behavior of HTTP `DELETE` for non-existing files. - // AWS assumes that - since file already doesn't exist - the goal is achieved, and returns HTTP 200. - // Azure, on the other hand, is somewhat stricter and check if file exists before deleting it. And if - // file doesn't exist - Azure returns HTTP 404. - // - // For us, it's totally okay if file didn't exist before removing. So when we get an HTTP 404 - - // just ignore it and report success. This way we can have a uniform library behavior for all clouds if (!response.ok && response.status !== 404) { throw new StagingError(`HTTP error ${response.status} ${response.statusText}`); } @@ -368,7 +213,6 @@ export default class DBSQLSession implements IDBSQLSession { method: 'PUT', headers: { ...headers, - // This header is required by server 'Content-Length': fileInfo.size.toString(), }, agent, @@ -387,16 +231,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getTypeInfo(request: TypeInfoRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getTypeInfo({ - sessionHandle: this.sessionHandle, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getTypeInfo(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -407,16 +244,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getCatalogs(request: CatalogsRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getCatalogs({ - sessionHandle: this.sessionHandle, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getCatalogs(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -427,18 +257,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getSchemas(request: SchemasRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getSchemas({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getSchemas(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -449,20 +270,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getTables(request: TablesRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getTables({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - tableName: request.tableName, - tableTypes: request.tableTypes, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getTables(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -473,16 +283,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getTableTypes(request: TableTypesRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getTableTypes({ - sessionHandle: this.sessionHandle, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getTableTypes(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -493,20 +296,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getColumns(request: ColumnsRequest = {}): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getColumns({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - tableName: request.tableName, - columnName: request.columnName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getColumns(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -517,36 +309,16 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getFunctions(request: FunctionsRequest): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getFunctions({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - functionName: request.functionName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getFunctions(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } public async getPrimaryKeys(request: PrimaryKeysRequest): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getPrimaryKeys({ - sessionHandle: this.sessionHandle, - catalogName: request.catalogName, - schemaName: request.schemaName, - tableName: request.tableName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getPrimaryKeys(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -557,22 +329,9 @@ export default class DBSQLSession implements IDBSQLSession { */ public async getCrossReference(request: CrossReferenceRequest): Promise { await this.failIfClosed(); - const driver = await this.context.getDriver(); - const clientConfig = this.context.getConfig(); - - const operationPromise = driver.getCrossReference({ - sessionHandle: this.sessionHandle, - parentCatalogName: request.parentCatalogName, - parentSchemaName: request.parentSchemaName, - parentTableName: request.parentTableName, - foreignCatalogName: request.foreignCatalogName, - foreignSchemaName: request.foreignSchemaName, - foreignTableName: request.foreignTableName, - runAsync: this.getRunAsyncForMetadataOperations(), - ...getDirectResultsOptions(request.maxRows, clientConfig), - }); - const response = await this.handleResponse(operationPromise); - return this.createOperation(response); + const opBackend = await this.backend.getCrossReference(request); + await this.failIfClosed(); + return this.wrapOperation(opBackend); } /** @@ -585,35 +344,20 @@ export default class DBSQLSession implements IDBSQLSession { return Status.success(); } - // Close owned operations one by one, removing successfully closed ones from the list await this.operations.closeAll(); - const driver = await this.context.getDriver(); - const response = await driver.closeSession({ - sessionHandle: this.sessionHandle, - }); - // check status for being successful - Status.assert(response.status); + const status = await this.backend.close(); - // notify owner connection this.onClose?.(); this.isOpen = false; this.context.getLogger().log(LogLevel.debug, `Session closed with id: ${this.id}`); - return new Status(response.status); + return status; } - private createOperation(response: OperationResponseShape): DBSQLOperation { - Status.assert(response.status); - const handle = definedOrError(response.operationHandle); - const operation = new DBSQLOperation({ - handle, - directResults: response.directResults, - context: this.context, - }); - + private wrapOperation(backend: IOperationBackend): DBSQLOperation { + const operation = new DBSQLOperation({ backend, context: this.context }); this.operations.add(operation); - return operation; } @@ -622,13 +366,4 @@ export default class DBSQLSession implements IDBSQLSession { throw new HiveDriverError('The session was closed or has expired'); } } - - private async handleResponse(requestPromise: Promise): Promise { - // Currently, after being closed sessions remains usable - server will not - // error out when trying to run operations on closed session. So it's - // basically useless to process any errors here - const result = await requestPromise; - await this.failIfClosed(); - return result; - } } diff --git a/lib/contracts/IBackend.ts b/lib/contracts/IBackend.ts new file mode 100644 index 00000000..847c25f7 --- /dev/null +++ b/lib/contracts/IBackend.ts @@ -0,0 +1,15 @@ +import { ConnectionOptions, OpenSessionRequest } from './IDBSQLClient'; +import ISessionBackend from './ISessionBackend'; + +/** + * Top-level backend dispatch handle. One instance per `DBSQLClient`, + * chosen at `connect()` time based on the `useSEA` flag and never + * re-selected per-call. + */ +export default interface IBackend { + connect(options: ConnectionOptions): Promise; + + openSession(request: OpenSessionRequest): Promise; + + close(): Promise; +} diff --git a/lib/contracts/IDBSQLClient.ts b/lib/contracts/IDBSQLClient.ts index 9c0d9670..f4b2a497 100644 --- a/lib/contracts/IDBSQLClient.ts +++ b/lib/contracts/IDBSQLClient.ts @@ -54,6 +54,12 @@ export type ConnectionOptions = { socketTimeout?: number; proxy?: ProxyOptions; enableMetricViewMetadata?: boolean; + /** + * Opt-in flag to dispatch through the Statement Execution API (SEA) backend + * instead of the default Thrift backend. Defaults to `false`. + * @internal Not stable; M0 stub only. + */ + useSEA?: boolean; } & AuthOptions; export interface OpenSessionRequest { diff --git a/lib/contracts/IOperationBackend.ts b/lib/contracts/IOperationBackend.ts new file mode 100644 index 00000000..1a4c1637 --- /dev/null +++ b/lib/contracts/IOperationBackend.ts @@ -0,0 +1,27 @@ +import { TGetOperationStatusResp, TGetResultSetMetadataResp } from '../../thrift/TCLIService_types'; +import Status from '../dto/Status'; +import { WaitUntilReadyOptions } from './IOperation'; + +/** + * What a `DBSQLOperation` needs from its backend. Returned by + * `ISessionBackend.executeStatement` and the metadata methods. + */ +export default interface IOperationBackend { + readonly id: string; + + readonly hasResultSet: boolean; + + fetchChunk(options: { limit: number; disableBuffering?: boolean }): Promise>; + + hasMore(): Promise; + + waitUntilReady(options?: WaitUntilReadyOptions): Promise; + + status(progress: boolean): Promise; + + getResultMetadata(): Promise; + + cancel(): Promise; + + close(): Promise; +} diff --git a/lib/contracts/ISessionBackend.ts b/lib/contracts/ISessionBackend.ts new file mode 100644 index 00000000..eb5fd818 --- /dev/null +++ b/lib/contracts/ISessionBackend.ts @@ -0,0 +1,39 @@ +import IOperationBackend from './IOperationBackend'; +import { + ExecuteStatementOptions, + TypeInfoRequest, + CatalogsRequest, + SchemasRequest, + TablesRequest, + TableTypesRequest, + ColumnsRequest, + FunctionsRequest, + PrimaryKeysRequest, + CrossReferenceRequest, +} from './IDBSQLSession'; +import Status from '../dto/Status'; +import InfoValue from '../dto/InfoValue'; + +/** + * What a `DBSQLSession` needs from its backend. Returned by + * `IBackend.openSession()`. Lifecycle tied to a single `DBSQLSession`. + */ +export default interface ISessionBackend { + readonly id: string; + + getInfo(infoType: number): Promise; + + executeStatement(statement: string, options: ExecuteStatementOptions): Promise; + + getTypeInfo(request: TypeInfoRequest): Promise; + getCatalogs(request: CatalogsRequest): Promise; + getSchemas(request: SchemasRequest): Promise; + getTables(request: TablesRequest): Promise; + getTableTypes(request: TableTypesRequest): Promise; + getColumns(request: ColumnsRequest): Promise; + getFunctions(request: FunctionsRequest): Promise; + getPrimaryKeys(request: PrimaryKeysRequest): Promise; + getCrossReference(request: CrossReferenceRequest): Promise; + + close(): Promise; +} diff --git a/lib/sea/SeaBackend.ts b/lib/sea/SeaBackend.ts new file mode 100644 index 00000000..5815dc05 --- /dev/null +++ b/lib/sea/SeaBackend.ts @@ -0,0 +1,18 @@ +import IBackend from '../contracts/IBackend'; +import ISessionBackend from '../contracts/ISessionBackend'; + +const NOT_IMPLEMENTED = 'SEA backend not implemented yet — wired in sea-napi-binding feature'; + +export default class SeaBackend implements IBackend { + public async connect(): Promise { + throw new Error(NOT_IMPLEMENTED); + } + + public async openSession(): Promise { + throw new Error(NOT_IMPLEMENTED); + } + + public async close(): Promise { + throw new Error(NOT_IMPLEMENTED); + } +} diff --git a/lib/thrift-backend/ThriftBackend.ts b/lib/thrift-backend/ThriftBackend.ts new file mode 100644 index 00000000..5e0e7570 --- /dev/null +++ b/lib/thrift-backend/ThriftBackend.ts @@ -0,0 +1,100 @@ +import Int64 from 'node-int64'; +import IBackend from '../contracts/IBackend'; +import ISessionBackend from '../contracts/ISessionBackend'; +import IClientContext from '../contracts/IClientContext'; +import { OpenSessionRequest } from '../contracts/IDBSQLClient'; +import { TProtocolVersion } from '../../thrift/TCLIService_types'; +import Status from '../dto/Status'; +import { definedOrError, serializeQueryTags } from '../utils'; +import ThriftSessionBackend from './ThriftSessionBackend'; + +function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) { + if (!catalogName && !schemaName) { + return {}; + } + + return { + initialNamespace: { + catalogName, + schemaName, + }, + }; +} + +interface ThriftBackendOptions { + context: IClientContext; + onConnectionEvent: (event: 'error' | 'reconnecting' | 'close' | 'timeout', payload?: unknown) => void; +} + +export default class ThriftBackend implements IBackend { + private readonly context: IClientContext; + + private readonly onConnectionEvent: ThriftBackendOptions['onConnectionEvent']; + + constructor({ context, onConnectionEvent }: ThriftBackendOptions) { + this.context = context; + this.onConnectionEvent = onConnectionEvent; + } + + public async connect(): Promise { + // The connection provider is owned by DBSQLClient (it implements IClientContext). + // We only need to wire the EventEmitter listeners through this backend. + const connectionProvider = await this.context.getConnectionProvider(); + const thriftConnection = await connectionProvider.getThriftConnection(); + + thriftConnection.on('error', (error: Error) => { + this.onConnectionEvent('error', error); + }); + + thriftConnection.on('reconnecting', (params: { delay: number; attempt: number }) => { + this.onConnectionEvent('reconnecting', params); + }); + + thriftConnection.on('close', () => { + this.onConnectionEvent('close'); + }); + + thriftConnection.on('timeout', () => { + this.onConnectionEvent('timeout'); + }); + } + + public async openSession(request: OpenSessionRequest): Promise { + const driver = await this.context.getDriver(); + const config = this.context.getConfig(); + + const configuration = request.configuration ? { ...request.configuration } : {}; + + if (config.enableMetricViewMetadata) { + configuration['spark.sql.thriftserver.metadata.metricview.enabled'] = 'true'; + } + + if (request.queryTags !== undefined) { + const serialized = serializeQueryTags(request.queryTags); + if (serialized) { + configuration.QUERY_TAGS = serialized; + } else { + delete configuration.QUERY_TAGS; + } + } + + const response = await driver.openSession({ + client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8), + ...getInitialNamespaceOptions(request.initialCatalog, request.initialSchema), + configuration, + canUseMultipleCatalogs: true, + }); + + Status.assert(response.status); + return new ThriftSessionBackend({ + handle: definedOrError(response.sessionHandle), + context: this.context, + serverProtocolVersion: response.serverProtocolVersion, + }); + } + + public async close(): Promise { + // DBSQLClient owns the connection lifecycle and clears its own state + // (connectionProvider, authProvider, thrift client) after this returns. + } +} diff --git a/lib/thrift-backend/ThriftOperationBackend.ts b/lib/thrift-backend/ThriftOperationBackend.ts new file mode 100644 index 00000000..e044d374 --- /dev/null +++ b/lib/thrift-backend/ThriftOperationBackend.ts @@ -0,0 +1,291 @@ +import { stringify, NIL } from 'uuid'; +import { + TGetOperationStatusResp, + TOperationHandle, + TSparkDirectResults, + TGetResultSetMetadataResp, + TSparkRowSetType, + TCloseOperationResp, + TOperationState, +} from '../../thrift/TCLIService_types'; +import IOperationBackend from '../contracts/IOperationBackend'; +import IClientContext from '../contracts/IClientContext'; +import Status from '../dto/Status'; +import { LogLevel } from '../contracts/IDBSQLLogger'; +import OperationStateError, { OperationStateErrorCode } from '../errors/OperationStateError'; +import IResultsProvider from '../result/IResultsProvider'; +import RowSetProvider from '../result/RowSetProvider'; +import JsonResultHandler from '../result/JsonResultHandler'; +import ArrowResultHandler from '../result/ArrowResultHandler'; +import CloudFetchResultHandler from '../result/CloudFetchResultHandler'; +import ArrowResultConverter from '../result/ArrowResultConverter'; +import ResultSlicer from '../result/ResultSlicer'; +import { definedOrError } from '../utils'; +import HiveDriverError from '../errors/HiveDriverError'; + +interface ThriftOperationBackendOptions { + handle: TOperationHandle; + directResults?: TSparkDirectResults; + context: IClientContext; +} + +async function delay(ms?: number): Promise { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} + +export default class ThriftOperationBackend implements IOperationBackend { + private readonly context: IClientContext; + + private readonly operationHandle: TOperationHandle; + + private readonly _data: RowSetProvider; + + private readonly closeOperation?: TCloseOperationResp; + + private metadata?: TGetResultSetMetadataResp; + + private metadataPromise?: Promise; + + private state: TOperationState = TOperationState.INITIALIZED_STATE; + + private operationStatus?: TGetOperationStatusResp; + + private resultHandler?: ResultSlicer; + + constructor({ handle, directResults, context }: ThriftOperationBackendOptions) { + this.operationHandle = handle; + this.context = context; + + const useOnlyPrefetchedResults = Boolean(directResults?.closeOperation); + + if (directResults?.operationStatus) { + this.processOperationStatusResponse(directResults.operationStatus); + } + + this.metadata = directResults?.resultSetMetadata; + this._data = new RowSetProvider( + this.context, + this.operationHandle, + [directResults?.resultSet], + useOnlyPrefetchedResults, + ); + this.closeOperation = directResults?.closeOperation; + } + + public get id(): string { + const operationId = this.operationHandle?.operationId?.guid; + return operationId ? stringify(operationId) : NIL; + } + + public get hasResultSet(): boolean { + return Boolean(this.operationHandle.hasResultSet); + } + + public async fetchChunk({ + limit, + disableBuffering, + }: { + limit: number; + disableBuffering?: boolean; + }): Promise> { + const resultHandler = await this.getResultHandler(); + + // All the library code is Promise-based, however, since Promises are microtasks, + // enqueueing a lot of promises may block macrotasks execution for a while. + // Usually, there are no much microtasks scheduled, however, when fetching query + // results (especially CloudFetch ones) it's quite easy to block event loop for + // long enough to break a lot of things. For example, with CloudFetch, after first + // set of files are downloaded and being processed immediately one by one, event + // loop easily gets blocked for enough time to break connection pool. `http.Agent` + // stops receiving socket events, and marks all sockets invalid on the next attempt + // to use them. See these similar issues that helped to debug this particular case - + // https://github.com/nodejs/node/issues/47130 and https://github.com/node-fetch/node-fetch/issues/1735 + await new Promise((resolve) => { + setTimeout(resolve, 0); + }); + + return resultHandler.fetchNext({ limit, disableBuffering }); + } + + public async hasMore(): Promise { + const resultHandler = await this.getResultHandler(); + return resultHandler.hasMore(); + } + + public async status(progress: boolean): Promise { + if (this.operationStatus) { + return this.operationStatus; + } + + const driver = await this.context.getDriver(); + const response = await driver.getOperationStatus({ + operationHandle: this.operationHandle, + getProgressUpdate: progress, + }); + + return this.processOperationStatusResponse(response); + } + + public async waitUntilReady(options?: { + progress?: boolean; + callback?: (progress: TGetOperationStatusResp) => unknown; + }): Promise { + if (this.state === TOperationState.FINISHED_STATE) { + return; + } + + let isReady = false; + + while (!isReady) { + // eslint-disable-next-line no-await-in-loop + const response = await this.status(Boolean(options?.progress)); + + if (options?.callback) { + // eslint-disable-next-line no-await-in-loop + await Promise.resolve(options.callback(response)); + } + + switch (response.operationState) { + case TOperationState.INITIALIZED_STATE: + case TOperationState.PENDING_STATE: + case TOperationState.RUNNING_STATE: + break; + + case TOperationState.FINISHED_STATE: + isReady = true; + break; + + case TOperationState.CANCELED_STATE: + throw new OperationStateError(OperationStateErrorCode.Canceled, response); + + case TOperationState.CLOSED_STATE: + throw new OperationStateError(OperationStateErrorCode.Closed, response); + + case TOperationState.ERROR_STATE: + throw new OperationStateError(OperationStateErrorCode.Error, response); + case TOperationState.TIMEDOUT_STATE: + throw new OperationStateError(OperationStateErrorCode.Timeout, response); + case TOperationState.UKNOWN_STATE: + default: + throw new OperationStateError(OperationStateErrorCode.Unknown, response); + } + + if (!isReady) { + // eslint-disable-next-line no-await-in-loop + await delay(100); + } + } + } + + public async getResultMetadata(): Promise { + if (this.metadata) { + return this.metadata; + } + + if (this.metadataPromise) { + return this.metadataPromise; + } + + this.metadataPromise = (async () => { + const driver = await this.context.getDriver(); + const metadata = await driver.getResultSetMetadata({ + operationHandle: this.operationHandle, + }); + Status.assert(metadata.status); + this.metadata = metadata; + return metadata; + })(); + + try { + return await this.metadataPromise; + } finally { + this.metadataPromise = undefined; + } + } + + public async cancel(): Promise { + this.context.getLogger().log(LogLevel.debug, `Cancelling operation with id: ${this.id}`); + const driver = await this.context.getDriver(); + const response = await driver.cancelOperation({ + operationHandle: this.operationHandle, + }); + Status.assert(response.status); + return new Status(response.status); + } + + public async close(): Promise { + this.context.getLogger().log(LogLevel.debug, `Closing operation with id: ${this.id}`); + const driver = await this.context.getDriver(); + const response = + this.closeOperation ?? + (await driver.closeOperation({ + operationHandle: this.operationHandle, + })); + Status.assert(response.status); + return new Status(response.status); + } + + private async getResultHandler(): Promise> { + const metadata = await this.getResultMetadata(); + const resultFormat = definedOrError(metadata.resultFormat); + + if (!this.resultHandler) { + let resultSource: IResultsProvider> | undefined; + + switch (resultFormat) { + case TSparkRowSetType.COLUMN_BASED_SET: + resultSource = new JsonResultHandler(this.context, this._data, metadata); + break; + case TSparkRowSetType.ARROW_BASED_SET: + resultSource = new ArrowResultConverter( + this.context, + new ArrowResultHandler(this.context, this._data, metadata), + metadata, + ); + break; + case TSparkRowSetType.URL_BASED_SET: + resultSource = new ArrowResultConverter( + this.context, + new CloudFetchResultHandler(this.context, this._data, metadata), + metadata, + ); + break; + // no default + } + + if (resultSource) { + this.resultHandler = new ResultSlicer(this.context, resultSource); + } + } + + if (!this.resultHandler) { + throw new HiveDriverError(`Unsupported result format: ${TSparkRowSetType[resultFormat]}`); + } + + return this.resultHandler; + } + + private processOperationStatusResponse(response: TGetOperationStatusResp) { + Status.assert(response.status); + + this.state = response.operationState ?? this.state; + + if (typeof response.hasResultSet === 'boolean') { + this.operationHandle.hasResultSet = response.hasResultSet; + } + + const isInProgress = [ + TOperationState.INITIALIZED_STATE, + TOperationState.PENDING_STATE, + TOperationState.RUNNING_STATE, + ].includes(this.state); + + if (!isInProgress) { + this.operationStatus = response; + } + + return response; + } +} diff --git a/lib/thrift-backend/ThriftSessionBackend.ts b/lib/thrift-backend/ThriftSessionBackend.ts new file mode 100644 index 00000000..916eb221 --- /dev/null +++ b/lib/thrift-backend/ThriftSessionBackend.ts @@ -0,0 +1,331 @@ +import { stringify, NIL } from 'uuid'; +import Int64 from 'node-int64'; +import { + TSessionHandle, + TStatus, + TOperationHandle, + TSparkDirectResults, + TSparkArrowTypes, + TSparkParameter, + TProtocolVersion, + TExecuteStatementReq, +} from '../../thrift/TCLIService_types'; +import ISessionBackend from '../contracts/ISessionBackend'; +import IOperationBackend from '../contracts/IOperationBackend'; +import IClientContext, { ClientConfig } from '../contracts/IClientContext'; +import { + ExecuteStatementOptions, + TypeInfoRequest, + CatalogsRequest, + SchemasRequest, + TablesRequest, + TableTypesRequest, + ColumnsRequest, + FunctionsRequest, + PrimaryKeysRequest, + CrossReferenceRequest, +} from '../contracts/IDBSQLSession'; +import Status from '../dto/Status'; +import InfoValue from '../dto/InfoValue'; +import { definedOrError, LZ4, ProtocolVersion, serializeQueryTags } from '../utils'; +import ParameterError from '../errors/ParameterError'; +import { DBSQLParameter, DBSQLParameterValue } from '../DBSQLParameter'; +import ThriftOperationBackend from './ThriftOperationBackend'; + +interface OperationResponseShape { + status: TStatus; + operationHandle?: TOperationHandle; + directResults?: TSparkDirectResults; +} + +export function numberToInt64(value: number | bigint | Int64): Int64 { + if (value instanceof Int64) { + return value; + } + + if (typeof value === 'bigint') { + const buffer = new ArrayBuffer(BigInt64Array.BYTES_PER_ELEMENT); + const view = new DataView(buffer); + view.setBigInt64(0, value, false); // `false` to use big-endian order + return new Int64(Buffer.from(buffer)); + } + + return new Int64(value); +} + +function getDirectResultsOptions(maxRows: number | bigint | Int64 | null | undefined, config: ClientConfig) { + if (maxRows === null) { + return {}; + } + + return { + getDirectResults: { + maxRows: numberToInt64(maxRows ?? config.directResultsDefaultMaxRows), + }, + }; +} + +function getArrowOptions( + config: ClientConfig, + serverProtocolVersion: TProtocolVersion | undefined | null, +): { + canReadArrowResult: boolean; + useArrowNativeTypes?: TSparkArrowTypes; +} { + const { arrowEnabled = true, useArrowNativeTypes = true } = config; + + if (!arrowEnabled || !ProtocolVersion.supportsArrowMetadata(serverProtocolVersion)) { + return { + canReadArrowResult: false, + }; + } + + return { + canReadArrowResult: true, + useArrowNativeTypes: { + timestampAsArrow: useArrowNativeTypes, + decimalAsArrow: useArrowNativeTypes, + complexTypesAsArrow: useArrowNativeTypes, + intervalTypesAsArrow: false, + }, + }; +} + +function getQueryParameters( + namedParameters?: Record, + ordinalParameters?: Array, +): Array { + const namedParametersProvided = namedParameters !== undefined && Object.keys(namedParameters).length > 0; + const ordinalParametersProvided = ordinalParameters !== undefined && ordinalParameters.length > 0; + + if (namedParametersProvided && ordinalParametersProvided) { + throw new ParameterError('Driver does not support both ordinal and named parameters.'); + } + + if (!namedParametersProvided && !ordinalParametersProvided) { + return []; + } + + const result: Array = []; + + if (namedParameters !== undefined) { + for (const name of Object.keys(namedParameters)) { + const value = namedParameters[name]; + const param = value instanceof DBSQLParameter ? value : new DBSQLParameter({ value }); + result.push(param.toSparkParameter({ name })); + } + } + + if (ordinalParameters !== undefined) { + for (const value of ordinalParameters) { + const param = value instanceof DBSQLParameter ? value : new DBSQLParameter({ value }); + result.push(param.toSparkParameter()); + } + } + + return result; +} + +interface ThriftSessionBackendOptions { + handle: TSessionHandle; + context: IClientContext; + serverProtocolVersion?: TProtocolVersion; +} + +export default class ThriftSessionBackend implements ISessionBackend { + private readonly context: IClientContext; + + private readonly sessionHandle: TSessionHandle; + + private readonly serverProtocolVersion?: TProtocolVersion; + + constructor({ handle, context, serverProtocolVersion }: ThriftSessionBackendOptions) { + this.sessionHandle = handle; + this.context = context; + this.serverProtocolVersion = serverProtocolVersion; + } + + private getRunAsyncForMetadataOperations(): boolean | undefined { + return ProtocolVersion.supportsAsyncMetadataOperations(this.serverProtocolVersion) ? true : undefined; + } + + public get id(): string { + const sessionId = this.sessionHandle?.sessionId?.guid; + return sessionId ? stringify(sessionId) : NIL; + } + + public async getInfo(infoType: number): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getInfo({ + sessionHandle: this.sessionHandle, + infoType, + }); + Status.assert(response.status); + return new InfoValue(response.infoValue); + } + + public async executeStatement(statement: string, options: ExecuteStatementOptions): Promise { + const driver = await this.context.getDriver(); + const clientConfig = this.context.getConfig(); + + const request = new TExecuteStatementReq({ + sessionHandle: this.sessionHandle, + statement, + queryTimeout: options.queryTimeout ? numberToInt64(options.queryTimeout) : undefined, + runAsync: true, + ...getDirectResultsOptions(options.maxRows, clientConfig), + ...getArrowOptions(clientConfig, this.serverProtocolVersion), + }); + + if (ProtocolVersion.supportsParameterizedQueries(this.serverProtocolVersion)) { + request.parameters = getQueryParameters(options.namedParameters, options.ordinalParameters); + } + + const serializedQueryTags = serializeQueryTags(options.queryTags); + if (serializedQueryTags !== undefined) { + request.confOverlay = { ...request.confOverlay, query_tags: serializedQueryTags }; + } + + if (ProtocolVersion.supportsCloudFetch(this.serverProtocolVersion)) { + request.canDownloadResult = options.useCloudFetch ?? clientConfig.useCloudFetch; + } + + if (ProtocolVersion.supportsArrowCompression(this.serverProtocolVersion) && request.canDownloadResult !== true) { + request.canDecompressLZ4Result = (options.useLZ4Compression ?? clientConfig.useLZ4Compression) && Boolean(LZ4()); + } + + const response = await driver.executeStatement(request); + return this.createOperationBackend(response); + } + + public async getTypeInfo(request: TypeInfoRequest): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getTypeInfo({ + sessionHandle: this.sessionHandle, + runAsync: this.getRunAsyncForMetadataOperations(), + ...getDirectResultsOptions(request.maxRows, this.context.getConfig()), + }); + return this.createOperationBackend(response); + } + + public async getCatalogs(request: CatalogsRequest): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getCatalogs({ + sessionHandle: this.sessionHandle, + runAsync: this.getRunAsyncForMetadataOperations(), + ...getDirectResultsOptions(request.maxRows, this.context.getConfig()), + }); + return this.createOperationBackend(response); + } + + public async getSchemas(request: SchemasRequest): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getSchemas({ + sessionHandle: this.sessionHandle, + catalogName: request.catalogName, + schemaName: request.schemaName, + runAsync: this.getRunAsyncForMetadataOperations(), + ...getDirectResultsOptions(request.maxRows, this.context.getConfig()), + }); + return this.createOperationBackend(response); + } + + public async getTables(request: TablesRequest): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getTables({ + sessionHandle: this.sessionHandle, + catalogName: request.catalogName, + schemaName: request.schemaName, + tableName: request.tableName, + tableTypes: request.tableTypes, + runAsync: this.getRunAsyncForMetadataOperations(), + ...getDirectResultsOptions(request.maxRows, this.context.getConfig()), + }); + return this.createOperationBackend(response); + } + + public async getTableTypes(request: TableTypesRequest): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getTableTypes({ + sessionHandle: this.sessionHandle, + runAsync: this.getRunAsyncForMetadataOperations(), + ...getDirectResultsOptions(request.maxRows, this.context.getConfig()), + }); + return this.createOperationBackend(response); + } + + public async getColumns(request: ColumnsRequest): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getColumns({ + sessionHandle: this.sessionHandle, + catalogName: request.catalogName, + schemaName: request.schemaName, + tableName: request.tableName, + columnName: request.columnName, + runAsync: this.getRunAsyncForMetadataOperations(), + ...getDirectResultsOptions(request.maxRows, this.context.getConfig()), + }); + return this.createOperationBackend(response); + } + + public async getFunctions(request: FunctionsRequest): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getFunctions({ + sessionHandle: this.sessionHandle, + catalogName: request.catalogName, + schemaName: request.schemaName, + functionName: request.functionName, + runAsync: this.getRunAsyncForMetadataOperations(), + ...getDirectResultsOptions(request.maxRows, this.context.getConfig()), + }); + return this.createOperationBackend(response); + } + + public async getPrimaryKeys(request: PrimaryKeysRequest): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getPrimaryKeys({ + sessionHandle: this.sessionHandle, + catalogName: request.catalogName, + schemaName: request.schemaName, + tableName: request.tableName, + runAsync: this.getRunAsyncForMetadataOperations(), + ...getDirectResultsOptions(request.maxRows, this.context.getConfig()), + }); + return this.createOperationBackend(response); + } + + public async getCrossReference(request: CrossReferenceRequest): Promise { + const driver = await this.context.getDriver(); + const response = await driver.getCrossReference({ + sessionHandle: this.sessionHandle, + parentCatalogName: request.parentCatalogName, + parentSchemaName: request.parentSchemaName, + parentTableName: request.parentTableName, + foreignCatalogName: request.foreignCatalogName, + foreignSchemaName: request.foreignSchemaName, + foreignTableName: request.foreignTableName, + runAsync: this.getRunAsyncForMetadataOperations(), + ...getDirectResultsOptions(request.maxRows, this.context.getConfig()), + }); + return this.createOperationBackend(response); + } + + public async close(): Promise { + const driver = await this.context.getDriver(); + const response = await driver.closeSession({ + sessionHandle: this.sessionHandle, + }); + Status.assert(response.status); + return new Status(response.status); + } + + private createOperationBackend(response: OperationResponseShape): IOperationBackend { + Status.assert(response.status); + const handle = definedOrError(response.operationHandle); + return new ThriftOperationBackend({ + handle, + directResults: response.directResults, + context: this.context, + }); + } +} diff --git a/tests/unit/DBSQLClient.test.ts b/tests/unit/DBSQLClient.test.ts index 4c0a3a34..8c3e64ce 100644 --- a/tests/unit/DBSQLClient.test.ts +++ b/tests/unit/DBSQLClient.test.ts @@ -2,6 +2,7 @@ import { expect, AssertionError } from 'chai'; import sinon from 'sinon'; import DBSQLClient, { ThriftLibrary } from '../../lib/DBSQLClient'; import DBSQLSession from '../../lib/DBSQLSession'; +import ThriftBackend from '../../lib/thrift-backend/ThriftBackend'; import PlainHttpAuthentication from '../../lib/connection/auth/PlainHttpAuthentication'; import DatabricksOAuth from '../../lib/connection/auth/DatabricksOAuth'; @@ -25,6 +26,19 @@ const connectOptions = { token: 'dapi********************************', } satisfies ConnectionOptions; +// Test helper: build a DBSQLClient with `getClient` stubbed to return the given +// ThriftClient stub, and pre-seed `client['backend']` with a ThriftBackend. +// Used to avoid 12 copies of the same 4-line setup across the openSession tests. +function makeStubbedClient(thriftClient: ThriftClientStub = new ThriftClientStub()): { + client: DBSQLClient; + thriftClient: ThriftClientStub; +} { + const client = new DBSQLClient(); + sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + client['backend'] = new ThriftBackend({ context: client, onConnectionEvent: () => {} }); + return { client, thriftClient }; +} + describe('DBSQLClient.connect', () => { it('should prepend "/" to path if it is missing', async () => { const client = new DBSQLClient(); @@ -103,18 +117,14 @@ describe('DBSQLClient.connect', () => { describe('DBSQLClient.openSession', () => { it('should successfully open session', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client } = makeStubbedClient(); const session = await client.openSession(); expect(session).instanceOf(DBSQLSession); }); it('should use initial namespace options', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); case1: { const initialCatalog = 'catalog1'; @@ -144,6 +154,7 @@ describe('DBSQLClient.openSession', () => { it('should throw an exception when not connected', async () => { const client = new DBSQLClient(); + client['backend'] = undefined; client['connectionProvider'] = undefined; try { @@ -158,15 +169,13 @@ describe('DBSQLClient.openSession', () => { }); it('should correctly pass server protocol version to session', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); // Test with default protocol version (SPARK_CLI_SERVICE_PROTOCOL_V8) { const session = await client.openSession(); expect(session).instanceOf(DBSQLSession); - expect((session as DBSQLSession)['serverProtocolVersion']).to.equal( + expect(((session as DBSQLSession)['backend'] as any)['serverProtocolVersion']).to.equal( TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, ); } @@ -179,16 +188,14 @@ describe('DBSQLClient.openSession', () => { const session = await client.openSession(); expect(session).instanceOf(DBSQLSession); - expect((session as DBSQLSession)['serverProtocolVersion']).to.equal( + expect(((session as DBSQLSession)['backend'] as any)['serverProtocolVersion']).to.equal( TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, ); } }); it('should pass session configuration to OpenSessionReq', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); const configuration = { QUERY_TAGS: 'team:engineering', ansi_mode: 'true' }; await client.openSession({ configuration }); @@ -196,9 +203,7 @@ describe('DBSQLClient.openSession', () => { }); it('should affect session behavior based on protocol version', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); // With protocol version V6 - should support async metadata operations { @@ -360,6 +365,7 @@ describe('DBSQLClient.close', () => { client['client'] = thriftClient; client['connectionProvider'] = new ConnectionProviderStub(); client['authProvider'] = new AuthProviderStub(); + client['backend'] = new ThriftBackend({ context: client, onConnectionEvent: () => {} }); const session = await client.openSession(); if (!(session instanceof DBSQLSession)) { @@ -583,9 +589,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should inject session parameter when enableMetricViewMetadata is true', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.connect({ ...connectOptions, enableMetricViewMetadata: true }); await client.openSession(); @@ -597,9 +601,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should not inject session parameter when enableMetricViewMetadata is false', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.connect({ ...connectOptions, enableMetricViewMetadata: false }); await client.openSession(); @@ -610,9 +612,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should not inject session parameter when enableMetricViewMetadata is not set', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.connect(connectOptions); await client.openSession(); @@ -623,9 +623,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should preserve user-provided session configuration', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.connect({ ...connectOptions, enableMetricViewMetadata: true }); const userConfig = { QUERY_TAGS: 'team:engineering', ansi_mode: 'true' }; @@ -638,9 +636,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should serialize queryTags dict and set in session configuration', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.openSession({ queryTags: { team: 'data-eng', project: 'etl' }, @@ -652,9 +648,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should let queryTags take precedence over configuration.QUERY_TAGS', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.openSession({ queryTags: { team: 'new-team' }, @@ -668,9 +662,7 @@ describe('DBSQLClient.enableMetricViewMetadata', () => { }); it('should remove QUERY_TAGS from configuration when queryTags is empty', async () => { - const client = new DBSQLClient(); - const thriftClient = new ThriftClientStub(); - sinon.stub(client, 'getClient').returns(Promise.resolve(thriftClient)); + const { client, thriftClient } = makeStubbedClient(); await client.openSession({ queryTags: {}, diff --git a/tests/unit/DBSQLOperation.test.ts b/tests/unit/DBSQLOperation.test.ts index b5f142ba..0c1872e8 100644 --- a/tests/unit/DBSQLOperation.test.ts +++ b/tests/unit/DBSQLOperation.test.ts @@ -49,8 +49,8 @@ describe('DBSQLOperation', () => { const context = new ClientContextStub(); const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); - expect(operation['state']).to.equal(TOperationState.INITIALIZED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.INITIALIZED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should pick up state from directResults', async () => { @@ -67,8 +67,8 @@ describe('DBSQLOperation', () => { }, }); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should fetch status and update internal state', async () => { @@ -79,15 +79,15 @@ describe('DBSQLOperation', () => { const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: false }), context }); - expect(operation['state']).to.equal(TOperationState.INITIALIZED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.INITIALIZED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; const status = await operation.status(); expect(driver.getOperationStatus.called).to.be.true; expect(status.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should request progress', async () => { @@ -110,8 +110,8 @@ describe('DBSQLOperation', () => { const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: false }), context }); - expect(operation['state']).to.equal(TOperationState.INITIALIZED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.INITIALIZED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; // First call - should fetch data and cache driver.getOperationStatusResp = { @@ -122,8 +122,8 @@ describe('DBSQLOperation', () => { expect(driver.getOperationStatus.callCount).to.equal(1); expect(status1.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; // Second call - should return cached data driver.getOperationStatusResp = { @@ -134,8 +134,8 @@ describe('DBSQLOperation', () => { expect(driver.getOperationStatus.callCount).to.equal(1); expect(status2.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should fetch status if directResults status is not finished', async () => { @@ -156,15 +156,15 @@ describe('DBSQLOperation', () => { }, }); - expect(operation['state']).to.equal(TOperationState.RUNNING_STATE); // from directResults - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.RUNNING_STATE); // from directResults + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; const status = await operation.status(false); expect(driver.getOperationStatus.called).to.be.true; expect(status.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.true; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.true; }); it('should not fetch status if directResults status is finished', async () => { @@ -185,15 +185,15 @@ describe('DBSQLOperation', () => { }, }); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); // from directResults - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); // from directResults + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; const status = await operation.status(false); expect(driver.getOperationStatus.called).to.be.false; expect(status.operationState).to.equal(TOperationState.FINISHED_STATE); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); - expect(operation['operationHandle'].hasResultSet).to.be.false; + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['operationHandle'].hasResultSet).to.be.false; }); it('should throw an error in case of a status error', async () => { @@ -439,12 +439,12 @@ describe('DBSQLOperation', () => { const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); - expect(operation['state']).to.equal(TOperationState.INITIALIZED_STATE); + expect((operation['backend'] as any)['state']).to.equal(TOperationState.INITIALIZED_STATE); await operation.finished(); expect(getOperationStatusStub.callCount).to.be.equal(attemptsUntilFinished); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); }); }, ); @@ -603,7 +603,7 @@ describe('DBSQLOperation', () => { expect(getOperationStatusStub.called).to.be.true; expect(schema).to.deep.equal(context.driver.getResultSetMetadataResp.schema); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); }); it('should request progress', async () => { @@ -752,7 +752,7 @@ describe('DBSQLOperation', () => { driver.getResultSetMetadata.resetHistory(); const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); - const resultHandler = await operation['getResultHandler'](); + const resultHandler = await (operation['backend'] as any)['getResultHandler'](); expect(driver.getResultSetMetadata.called).to.be.true; expect(resultHandler).to.be.instanceOf(ResultSlicer); expect(resultHandler['source']).to.be.instanceOf(JsonResultHandler); @@ -763,7 +763,7 @@ describe('DBSQLOperation', () => { driver.getResultSetMetadata.resetHistory(); const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); - const resultHandler = await operation['getResultHandler'](); + const resultHandler = await (operation['backend'] as any)['getResultHandler'](); expect(driver.getResultSetMetadata.called).to.be.true; expect(resultHandler).to.be.instanceOf(ResultSlicer); expect(resultHandler['source']).to.be.instanceOf(ArrowResultConverter); @@ -778,7 +778,7 @@ describe('DBSQLOperation', () => { driver.getResultSetMetadata.resetHistory(); const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); - const resultHandler = await operation['getResultHandler'](); + const resultHandler = await (operation['backend'] as any)['getResultHandler'](); expect(driver.getResultSetMetadata.called).to.be.true; expect(resultHandler).to.be.instanceOf(ResultSlicer); expect(resultHandler['source']).to.be.instanceOf(ArrowResultConverter); @@ -828,7 +828,7 @@ describe('DBSQLOperation', () => { expect(getOperationStatusStub.called).to.be.true; expect(results).to.deep.equal([]); - expect(operation['state']).to.equal(TOperationState.FINISHED_STATE); + expect((operation['backend'] as any)['state']).to.equal(TOperationState.FINISHED_STATE); }); it('should request progress', async () => { @@ -1041,10 +1041,10 @@ describe('DBSQLOperation', () => { const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.false; - expect(operation['_data']['hasMoreRowsFlag']).to.be.false; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.false; }); it('should return False if operation was closed', async () => { @@ -1086,10 +1086,10 @@ describe('DBSQLOperation', () => { const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.true; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.true; }); it('should return True if hasMoreRows flag is False but there is actual data', async () => { @@ -1101,10 +1101,10 @@ describe('DBSQLOperation', () => { const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.true; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.true; }); it('should return True if hasMoreRows flag is unset but there is actual data', async () => { @@ -1116,10 +1116,10 @@ describe('DBSQLOperation', () => { const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.true; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.true; }); it('should return False if hasMoreRows flag is False and there is no data', async () => { @@ -1132,10 +1132,10 @@ describe('DBSQLOperation', () => { const operation = new DBSQLOperation({ handle: operationHandleStub({ hasResultSet: true }), context }); expect(await operation.hasMoreRows()).to.be.true; - expect(operation['_data']['hasMoreRowsFlag']).to.be.undefined; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.undefined; await operation.fetchChunk({ disableBuffering: true }); expect(await operation.hasMoreRows()).to.be.false; - expect(operation['_data']['hasMoreRowsFlag']).to.be.false; + expect((operation['backend'] as any)['_data']['hasMoreRowsFlag']).to.be.false; }); });