diff --git a/src/main/kotlin/dev/typetype/server/routes/DownloaderGatewayArtifactProxy.kt b/src/main/kotlin/dev/typetype/server/routes/DownloaderGatewayArtifactProxy.kt new file mode 100644 index 0000000..e3718d6 --- /dev/null +++ b/src/main/kotlin/dev/typetype/server/routes/DownloaderGatewayArtifactProxy.kt @@ -0,0 +1,84 @@ +package dev.typetype.server.routes + +import dev.typetype.server.models.ErrorResponse +import dev.typetype.server.services.DownloaderGatewayResponse +import dev.typetype.server.services.DownloaderGatewayService +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.ApplicationCall +import io.ktor.server.response.respond +import io.ktor.server.response.respondOutputStream +import okhttp3.Response + +suspend fun forwardDownloaderArtifactRequest( + call: ApplicationCall, + gateway: DownloaderGatewayService, + response: DownloaderGatewayResponse, + requestHeaders: Map, + forceDownload: Boolean, +) { + val location = artifactHeader(response, HttpHeaders.Location) + if (location == null) { + call.respond(HttpStatusCode.BadGateway, ErrorResponse("artifact unavailable")) + return + } + + val upstream = runCatching { gateway.openFetchAbsolute(location, requestHeaders) } + .getOrElse { + call.respond(HttpStatusCode.BadGateway, ErrorResponse("artifact unavailable")) + return + } + + val artifact = upstream + val headers = artifactHeaders(artifact) + headers.forEach { (name, value) -> + if (shouldForwardArtifactResponseHeader(name, forceDownload)) { + call.response.headers.append(name, value, safeOnly = false) + } + } + if (forceDownload) applyArtifactDownloadHeaders(call, artifactResponse(artifact, headers)) + + val status = HttpStatusCode.fromValue(artifact.code) + val contentType = artifactContentType(artifact, forceDownload) + try { + call.respondOutputStream(contentType = contentType, status = status) { + artifact.use { response -> + response.body.byteStream().use { input -> + input.copyTo(this, DEFAULT_BUFFER_SIZE) + } + } + } + } catch (error: Throwable) { + artifact.close() + throw error + } +} + +private fun artifactHeader(response: DownloaderGatewayResponse, name: String): String? = + response.headers.firstOrNull { it.first.equals(name, ignoreCase = true) }?.second + +private fun artifactHeaders(response: Response): List> = + response.headers.names().flatMap { name -> response.headers(name).map { name to it } } + +private fun artifactResponse(response: Response, headers: List>): DownloaderGatewayResponse = + DownloaderGatewayResponse( + status = response.code, + contentType = response.header(HttpHeaders.ContentType), + headers = headers, + body = ByteArray(0), + ) + +private fun artifactContentType(response: Response, forceDownload: Boolean): ContentType { + if (forceDownload) return ContentType.Application.OctetStream + return response.header(HttpHeaders.ContentType) + ?.let { runCatching { ContentType.parse(it) }.getOrNull() } + ?: ContentType.Application.OctetStream +} + +private fun shouldForwardArtifactResponseHeader(name: String, forceDownload: Boolean): Boolean { + val lower = name.lowercase() + if (lower == "transfer-encoding" || lower == "connection") return false + if (lower == "content-length") return true + return shouldForwardGatewayResponseHeader(name, forceDownload) +} diff --git a/src/main/kotlin/dev/typetype/server/routes/DownloaderGatewayRoutes.kt b/src/main/kotlin/dev/typetype/server/routes/DownloaderGatewayRoutes.kt index 6c33b4f..89da87a 100644 --- a/src/main/kotlin/dev/typetype/server/routes/DownloaderGatewayRoutes.kt +++ b/src/main/kotlin/dev/typetype/server/routes/DownloaderGatewayRoutes.kt @@ -62,38 +62,27 @@ private suspend fun forwardDownloaderRequest(call: ApplicationCall, gateway: Dow return } - val effectiveResponse = if (shouldProxyArtifact(path, response)) { - val location = headerValue(response, "Location") - if (location == null) { - response - } else { - runCatching { gateway.fetchAbsolute(location, requestHeaders) } - .getOrElse { - call.respond(HttpStatusCode.BadGateway, ErrorResponse("artifact unavailable")) - return - } - } - } else { - response - } - val forceDownload = shouldForceArtifactDownload(path, query) + if (shouldProxyArtifact(path, response)) { + forwardDownloaderArtifactRequest(call, gateway, response, requestHeaders, forceDownload) + return + } - effectiveResponse.headers.forEach { (name, value) -> + response.headers.forEach { (name, value) -> if (shouldForwardGatewayResponseHeader(name, forceDownload)) { call.response.headers.append(name, value, safeOnly = false) } } - if (forceDownload) applyArtifactDownloadHeaders(call, effectiveResponse) + if (forceDownload) applyArtifactDownloadHeaders(call, response) - val status = HttpStatusCode.fromValue(effectiveResponse.status) + val status = HttpStatusCode.fromValue(response.status) val contentType = if (forceDownload) { ContentType.Application.OctetStream } else { - effectiveResponse.contentType?.let { runCatching { ContentType.parse(it) }.getOrNull() } + response.contentType?.let { runCatching { ContentType.parse(it) }.getOrNull() } ?: ContentType.Application.OctetStream } - call.respondBytes(effectiveResponse.body, contentType = contentType, status = status) + call.respondBytes(response.body, contentType = contentType, status = status) } private fun hasRequestBody(method: String): Boolean = method == "POST" || method == "PUT" || method == "PATCH" diff --git a/src/main/kotlin/dev/typetype/server/services/DownloaderGatewayService.kt b/src/main/kotlin/dev/typetype/server/services/DownloaderGatewayService.kt index 8018248..db47985 100644 --- a/src/main/kotlin/dev/typetype/server/services/DownloaderGatewayService.kt +++ b/src/main/kotlin/dev/typetype/server/services/DownloaderGatewayService.kt @@ -44,18 +44,6 @@ class DownloaderGatewayService( return client.newCall(requestBuilder.build()).execute() } - fun fetchAbsolute(url: String, headers: Map): DownloaderGatewayResponse { - openFetchAbsolute(url, headers).use { response -> - val responseHeaders = response.headers.names().flatMap { name -> response.headers(name).map { name to it } } - return DownloaderGatewayResponse( - status = response.code, - contentType = response.header("Content-Type"), - headers = responseHeaders, - body = response.body.bytes(), - ) - } - } - fun openFetchAbsolute(url: String, headers: Map): Response { val requestBuilder = Request.Builder().url(url).method("GET", null) headers["Range"]?.takeIf { it.isNotBlank() }?.let { requestBuilder.addHeader("Range", it) } diff --git a/src/test/kotlin/dev/typetype/server/DownloaderGatewayArtifactProxyTest.kt b/src/test/kotlin/dev/typetype/server/DownloaderGatewayArtifactProxyTest.kt new file mode 100644 index 0000000..5943c57 --- /dev/null +++ b/src/test/kotlin/dev/typetype/server/DownloaderGatewayArtifactProxyTest.kt @@ -0,0 +1,74 @@ +package dev.typetype.server + +import com.sun.net.httpserver.HttpServer +import dev.typetype.server.routes.downloaderGatewayRoutes +import dev.typetype.server.services.DownloaderGatewayService +import io.ktor.client.request.get +import io.ktor.client.request.header +import io.ktor.client.statement.bodyAsText +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.server.routing.routing +import io.ktor.server.testing.testApplication +import java.net.InetAddress +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicReference +import okhttp3.Dns +import okhttp3.OkHttpClient +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test + +class DownloaderGatewayArtifactProxyTest { + @Test + fun `internal artifact redirect streams range response`() = testApplication { + val requestedRange = AtomicReference() + val upstream = HttpServer.create(InetSocketAddress(0), 0) + upstream.createContext("/jobs/test/artifact") { exchange -> + exchange.responseHeaders.add(HttpHeaders.Location, "http://garage:${upstream.address.port}/object") + exchange.sendResponseHeaders(302, -1) + exchange.close() + } + upstream.createContext("/object") { exchange -> + requestedRange.set(exchange.requestHeaders.getFirst(HttpHeaders.Range)) + val payload = "abc".toByteArray() + exchange.responseHeaders.add(HttpHeaders.ContentType, "video/mp4") + exchange.responseHeaders.add(HttpHeaders.ContentDisposition, "inline; filename=demo.mp4") + exchange.responseHeaders.add(HttpHeaders.AcceptRanges, "bytes") + exchange.responseHeaders.add(HttpHeaders.ContentRange, "bytes 0-2/6") + exchange.sendResponseHeaders(206, payload.size.toLong()) + exchange.responseBody.use { it.write(payload) } + } + upstream.start() + + val gateway = DownloaderGatewayService( + baseUrl = "http://127.0.0.1:${upstream.address.port}", + client = OkHttpClient.Builder().dns(testDns()).followRedirects(false).followSslRedirects(false).build(), + ) + + application { + routing { + downloaderGatewayRoutes(gateway) + } + } + + try { + val response = client.get("/downloader/jobs/test/artifact") { + header(HttpHeaders.Range, "bytes=0-2") + } + assertEquals(HttpStatusCode.PartialContent, response.status) + assertEquals("bytes=0-2", requestedRange.get()) + assertEquals("3", response.headers[HttpHeaders.ContentLength]) + assertEquals("bytes", response.headers[HttpHeaders.AcceptRanges]) + assertEquals("bytes 0-2/6", response.headers[HttpHeaders.ContentRange]) + assertTrue(response.headers[HttpHeaders.ContentDisposition].orEmpty().startsWith("attachment")) + assertEquals("abc", response.bodyAsText()) + } finally { + upstream.stop(0) + } + } + + private fun testDns(): Dns = Dns { hostname -> + if (hostname == "garage") listOf(InetAddress.getByName("127.0.0.1")) else Dns.SYSTEM.lookup(hostname) + } +}