Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package dev.typetype.server.routes

import dev.typetype.server.models.ErrorResponse
import dev.typetype.server.services.DownloaderGatewayResponse
import dev.typetype.server.services.DownloaderGatewayService
import io.ktor.http.ContentType
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.server.application.ApplicationCall
import io.ktor.server.response.respond
import io.ktor.server.response.respondOutputStream
import okhttp3.Response

suspend fun forwardDownloaderArtifactRequest(
call: ApplicationCall,
gateway: DownloaderGatewayService,
response: DownloaderGatewayResponse,
requestHeaders: Map<String, String>,
forceDownload: Boolean,
) {
val location = artifactHeader(response, HttpHeaders.Location)
if (location == null) {
call.respond(HttpStatusCode.BadGateway, ErrorResponse("artifact unavailable"))
return
}

val upstream = runCatching { gateway.openFetchAbsolute(location, requestHeaders) }
.getOrElse {
call.respond(HttpStatusCode.BadGateway, ErrorResponse("artifact unavailable"))
return
}

val artifact = upstream
val headers = artifactHeaders(artifact)
headers.forEach { (name, value) ->
if (shouldForwardArtifactResponseHeader(name, forceDownload)) {
call.response.headers.append(name, value, safeOnly = false)
}
}
if (forceDownload) applyArtifactDownloadHeaders(call, artifactResponse(artifact, headers))

val status = HttpStatusCode.fromValue(artifact.code)
val contentType = artifactContentType(artifact, forceDownload)
try {
call.respondOutputStream(contentType = contentType, status = status) {
artifact.use { response ->
response.body.byteStream().use { input ->
input.copyTo(this, DEFAULT_BUFFER_SIZE)
}
}
}
} catch (error: Throwable) {
artifact.close()
throw error
}
}

private fun artifactHeader(response: DownloaderGatewayResponse, name: String): String? =
response.headers.firstOrNull { it.first.equals(name, ignoreCase = true) }?.second

private fun artifactHeaders(response: Response): List<Pair<String, String>> =
response.headers.names().flatMap { name -> response.headers(name).map { name to it } }

private fun artifactResponse(response: Response, headers: List<Pair<String, String>>): DownloaderGatewayResponse =
DownloaderGatewayResponse(
status = response.code,
contentType = response.header(HttpHeaders.ContentType),
headers = headers,
body = ByteArray(0),
)

private fun artifactContentType(response: Response, forceDownload: Boolean): ContentType {
if (forceDownload) return ContentType.Application.OctetStream
return response.header(HttpHeaders.ContentType)
?.let { runCatching { ContentType.parse(it) }.getOrNull() }
?: ContentType.Application.OctetStream
}

private fun shouldForwardArtifactResponseHeader(name: String, forceDownload: Boolean): Boolean {
val lower = name.lowercase()
if (lower == "transfer-encoding" || lower == "connection") return false
if (lower == "content-length") return true
return shouldForwardGatewayResponseHeader(name, forceDownload)
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,38 +62,27 @@ private suspend fun forwardDownloaderRequest(call: ApplicationCall, gateway: Dow
return
}

val effectiveResponse = if (shouldProxyArtifact(path, response)) {
val location = headerValue(response, "Location")
if (location == null) {
response
} else {
runCatching { gateway.fetchAbsolute(location, requestHeaders) }
.getOrElse {
call.respond(HttpStatusCode.BadGateway, ErrorResponse("artifact unavailable"))
return
}
}
} else {
response
}

val forceDownload = shouldForceArtifactDownload(path, query)
if (shouldProxyArtifact(path, response)) {
forwardDownloaderArtifactRequest(call, gateway, response, requestHeaders, forceDownload)
return
}

effectiveResponse.headers.forEach { (name, value) ->
response.headers.forEach { (name, value) ->
if (shouldForwardGatewayResponseHeader(name, forceDownload)) {
call.response.headers.append(name, value, safeOnly = false)
}
}
if (forceDownload) applyArtifactDownloadHeaders(call, effectiveResponse)
if (forceDownload) applyArtifactDownloadHeaders(call, response)

val status = HttpStatusCode.fromValue(effectiveResponse.status)
val status = HttpStatusCode.fromValue(response.status)
val contentType = if (forceDownload) {
ContentType.Application.OctetStream
} else {
effectiveResponse.contentType?.let { runCatching { ContentType.parse(it) }.getOrNull() }
response.contentType?.let { runCatching { ContentType.parse(it) }.getOrNull() }
?: ContentType.Application.OctetStream
}
call.respondBytes(effectiveResponse.body, contentType = contentType, status = status)
call.respondBytes(response.body, contentType = contentType, status = status)
}

private fun hasRequestBody(method: String): Boolean = method == "POST" || method == "PUT" || method == "PATCH"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,6 @@ class DownloaderGatewayService(
return client.newCall(requestBuilder.build()).execute()
}

fun fetchAbsolute(url: String, headers: Map<String, String>): DownloaderGatewayResponse {
openFetchAbsolute(url, headers).use { response ->
val responseHeaders = response.headers.names().flatMap { name -> response.headers(name).map { name to it } }
return DownloaderGatewayResponse(
status = response.code,
contentType = response.header("Content-Type"),
headers = responseHeaders,
body = response.body.bytes(),
)
}
}

fun openFetchAbsolute(url: String, headers: Map<String, String>): Response {
val requestBuilder = Request.Builder().url(url).method("GET", null)
headers["Range"]?.takeIf { it.isNotBlank() }?.let { requestBuilder.addHeader("Range", it) }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package dev.typetype.server

import com.sun.net.httpserver.HttpServer
import dev.typetype.server.routes.downloaderGatewayRoutes
import dev.typetype.server.services.DownloaderGatewayService
import io.ktor.client.request.get
import io.ktor.client.request.header
import io.ktor.client.statement.bodyAsText
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.server.routing.routing
import io.ktor.server.testing.testApplication
import java.net.InetAddress
import java.net.InetSocketAddress
import java.util.concurrent.atomic.AtomicReference
import okhttp3.Dns
import okhttp3.OkHttpClient
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test

class DownloaderGatewayArtifactProxyTest {
@Test
fun `internal artifact redirect streams range response`() = testApplication {
val requestedRange = AtomicReference<String>()
val upstream = HttpServer.create(InetSocketAddress(0), 0)
upstream.createContext("/jobs/test/artifact") { exchange ->
exchange.responseHeaders.add(HttpHeaders.Location, "http://garage:${upstream.address.port}/object")
exchange.sendResponseHeaders(302, -1)
exchange.close()
}
upstream.createContext("/object") { exchange ->
requestedRange.set(exchange.requestHeaders.getFirst(HttpHeaders.Range))
val payload = "abc".toByteArray()
exchange.responseHeaders.add(HttpHeaders.ContentType, "video/mp4")
exchange.responseHeaders.add(HttpHeaders.ContentDisposition, "inline; filename=demo.mp4")
exchange.responseHeaders.add(HttpHeaders.AcceptRanges, "bytes")
exchange.responseHeaders.add(HttpHeaders.ContentRange, "bytes 0-2/6")
exchange.sendResponseHeaders(206, payload.size.toLong())
exchange.responseBody.use { it.write(payload) }
}
upstream.start()

val gateway = DownloaderGatewayService(
baseUrl = "http://127.0.0.1:${upstream.address.port}",
client = OkHttpClient.Builder().dns(testDns()).followRedirects(false).followSslRedirects(false).build(),
)

application {
routing {
downloaderGatewayRoutes(gateway)
}
}

try {
val response = client.get("/downloader/jobs/test/artifact") {
header(HttpHeaders.Range, "bytes=0-2")
}
assertEquals(HttpStatusCode.PartialContent, response.status)
assertEquals("bytes=0-2", requestedRange.get())
assertEquals("3", response.headers[HttpHeaders.ContentLength])
assertEquals("bytes", response.headers[HttpHeaders.AcceptRanges])
assertEquals("bytes 0-2/6", response.headers[HttpHeaders.ContentRange])
assertTrue(response.headers[HttpHeaders.ContentDisposition].orEmpty().startsWith("attachment"))
assertEquals("abc", response.bodyAsText())
} finally {
upstream.stop(0)
}
}

private fun testDns(): Dns = Dns { hostname ->
if (hostname == "garage") listOf(InetAddress.getByName("127.0.0.1")) else Dns.SYSTEM.lookup(hostname)
}
}