diff --git a/openapi.yaml b/openapi.yaml
index 72d0e8c..de14cd6 100644
--- a/openapi.yaml
+++ b/openapi.yaml
@@ -7551,6 +7551,134 @@ paths:
           schema:
             description: Training session ID
             type: string
+  /rl/training-sessions/{session_id}/operations/custom-forward-backward:
+    post:
+      summary: Custom forward-backward pass
+      description: Submits a forward-backward pass driven by externally computed gradients of the loss with respect to per-token log-probabilities.
+      operationId: customForwardBackward
+      tags: [RL]
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/RL.CustomForwardBackwardBody'
+        required: true
+      responses:
+        "200":
+          description: Submitted custom forward-backward operation
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/RL.CustomForwardBackwardOperation'
+        default:
+          description: An unexpected error response.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ErrorData'
+      parameters:
+        - name: session_id
+          in: path
+          required: true
+          schema:
+            description: Training session ID
+            type: string
+  /rl/training-sessions/{session_id}/operations/custom-forward-backward/{operation_id}:
+    get:
+      summary: Get custom forward-backward operation
+      description: Retrieves the current status and result of a custom forward-backward operation.
+      operationId: getCustomForwardBackwardOperation
+      tags: [RL]
+      responses:
+        "200":
+          description: Custom forward-backward operation details
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/RL.CustomForwardBackwardOperation'
+        default:
+          description: An unexpected error response.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ErrorData'
+      parameters:
+        - name: session_id
+          in: path
+          required: true
+          schema:
+            description: Training session ID
+            type: string
+        - name: operation_id
+          in: path
+          required: true
+          schema:
+            description: Operation ID
+            type: string
+  /rl/training-sessions/{session_id}/operations/forward:
+    post:
+      summary: Forward pass
+      description: Submits a forward operation that will asynchronously run a no-grad forward pass and return per-token log-probabilities for each sample.
+      operationId: forward
+      tags: [RL]
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/RL.ForwardBody'
+        required: true
+      responses:
+        "200":
+          description: Submitted forward operation
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/RL.ForwardOperation'
+        default:
+          description: An unexpected error response.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ErrorData'
+      parameters:
+        - name: session_id
+          in: path
+          required: true
+          schema:
+            description: Training session ID
+            type: string
+  /rl/training-sessions/{session_id}/operations/forward/{operation_id}:
+    get:
+      summary: Get forward operation
+      description: Retrieves the current status and result of a forward operation.
+      operationId: getForwardOperation
+      tags: [RL]
+      responses:
+        "200":
+          description: Forward operation details
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/RL.ForwardOperation'
+        default:
+          description: An unexpected error response.
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ErrorData'
+      parameters:
+        - name: session_id
+          in: path
+          required: true
+          schema:
+            description: Training session ID
+            type: string
+        - name: operation_id
+          in: path
+          required: true
+          schema:
+            description: Operation ID
+            type: string
   /rl/training-sessions/{session_id}/operations/optim-step:
     post:
       summary: Optimizer step
@@ -8241,6 +8369,118 @@ components:
           additionalProperties:
             type: number
             format: double
+    RL.CustomForwardBackwardBody:
+      type: object
+      description: Request body for a custom forward-backward pass.
+      required:
+        - samples
+        - gradients
+      properties:
+        samples:
+          type: array
+          description: Batch of training samples
+          items:
+            $ref: '#/components/schemas/RL.TrainingSample'
+        gradients:
+          type: array
+          description: Per-sample per-token gradients of the loss with respect to log-probabilities
+          items:
+            $ref: '#/components/schemas/RL.TargetLogprobGradients'
+    RL.CustomForwardBackwardOperation:
+      type: object
+      description: Async custom forward-backward pass operation
+      required:
+        - id
+        - status
+      properties:
+        id:
+          type: string
+          example: 550e8400-e29b-41d4-a716-446655440000
+          description: Operation ID
+        status:
+          $ref: '#/components/schemas/RL.TrainingOperationStatus'
+          example: TRAINING_OPERATION_STATUS_PENDING
+          description: Operation status
+        output:
+          $ref: '#/components/schemas/RL.CustomForwardBackwardResult'
+          description: Result on success
+        error:
+          $ref: '#/components/schemas/RL.TrainingOperationError'
+          description: Error details on failure
+    RL.CustomForwardBackwardResult:
+      type: object
+      description: Result of a custom forward-backward pass operation
+    RL.ForwardBody:
+      type: object
+      description: Request body for a forward pass.
+      required:
+        - samples
+      properties:
+        samples:
+          type: array
+          description: Batch of training samples for which to compute per-token log-probabilities
+          items:
+            $ref: '#/components/schemas/RL.TrainingSample'
+    RL.ForwardOperation:
+      type: object
+      description: Async forward pass operation
+      required:
+        - id
+        - status
+      properties:
+        id:
+          type: string
+          example: 550e8400-e29b-41d4-a716-446655440000
+          description: Operation ID
+        status:
+          $ref: '#/components/schemas/RL.TrainingOperationStatus'
+          example: TRAINING_OPERATION_STATUS_PENDING
+          description: Operation status
+        output:
+          $ref: '#/components/schemas/RL.ForwardResult'
+          description: Result on success
+        error:
+          $ref: '#/components/schemas/RL.TrainingOperationError'
+          description: Error details on failure
+    RL.ForwardResult:
+      type: object
+      description: Result of a forward pass operation
+      properties:
+        logprobs:
+          type: array
+          description: Per-sample per-token log-probabilities
+          items:
+            $ref: '#/components/schemas/RL.TargetLogprobs'
+    RL.TargetLogprobs:
+      type: object
+      description: Per-token log-probabilities from the target model
+      required:
+        - data
+      properties:
+        data:
+          type: array
+          description: Float array of per-token log probabilities
+          example: [-1.2, -0.8, -1.5, -0.9, -1.1]
+          items:
+            type: number
+            format: float
+    RL.TargetLogprobGradients:
+      type: object
+      description: Per-token gradients of the loss with respect to target log-probabilities
+      required:
+        - data
+      properties:
+        data:
+          type: array
+          description: Float array of per-token gradients (d loss / d log p)
+          example: [-0.1, 0.05, -0.08, 0.12, -0.03]
+          items:
+            type: number
+            format: float
+        dtype:
+          $ref: '#/components/schemas/RL.DType'
+          example: D_TYPE_FLOAT32
+          description: Data type of the float array
     RL.TrainingOperationError:
       type: object
       description: Error details for a failed training operation