|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the Elastic License |
| 4 | + * 2.0; you may not use this file except in compliance with the Elastic License |
| 5 | + * 2.0. |
| 6 | + */ |
| 7 | + |
| 8 | +package org.elasticsearch.xpack.inference; |
| 9 | + |
| 10 | +import com.sun.net.httpserver.HttpExchange; |
| 11 | +import com.sun.net.httpserver.HttpServer; |
| 12 | + |
| 13 | +import org.apache.http.HttpHeaders; |
| 14 | +import org.apache.http.HttpStatus; |
| 15 | +import org.apache.http.client.utils.URIBuilder; |
| 16 | +import org.elasticsearch.logging.LogManager; |
| 17 | +import org.elasticsearch.logging.Logger; |
| 18 | +import org.elasticsearch.test.fixture.HttpHeaderParser; |
| 19 | +import org.elasticsearch.xcontent.XContentParser; |
| 20 | +import org.elasticsearch.xcontent.XContentParserConfiguration; |
| 21 | +import org.elasticsearch.xcontent.XContentType; |
| 22 | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; |
| 23 | +import org.junit.rules.TestRule; |
| 24 | +import org.junit.runner.Description; |
| 25 | +import org.junit.runners.model.Statement; |
| 26 | + |
| 27 | +import java.io.ByteArrayInputStream; |
| 28 | +import java.io.IOException; |
| 29 | +import java.io.InputStream; |
| 30 | +import java.io.OutputStream; |
| 31 | +import java.net.InetSocketAddress; |
| 32 | +import java.nio.charset.StandardCharsets; |
| 33 | +import java.util.Random; |
| 34 | +import java.util.concurrent.ExecutorService; |
| 35 | +import java.util.concurrent.Executors; |
| 36 | + |
| 37 | +/** |
| 38 | + * Simple model server to serve ML models. |
| 39 | + * The URL path corresponds to a file name in this class's resources. |
| 40 | + * If the file is found, its content is returned, otherwise 404. |
| 41 | + * Respects a range header to serve partial content. |
| 42 | + */ |
| 43 | +public class MlModelServer implements TestRule { |
| 44 | + |
| 45 | + private static final String HOST = "localhost"; |
| 46 | + private static final Logger logger = LogManager.getLogger(MlModelServer.class); |
| 47 | + |
| 48 | + private int port; |
| 49 | + |
| 50 | + public String getUrl() { |
| 51 | + return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString(); |
| 52 | + } |
| 53 | + |
| 54 | + private void handle(HttpExchange exchange) throws IOException { |
| 55 | + String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE); |
| 56 | + HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null; |
| 57 | + logger.info("request: {} range={}", exchange.getRequestURI().getPath(), range); |
| 58 | + |
| 59 | + try (InputStream is = getInputStream(exchange)) { |
| 60 | + int httpStatus; |
| 61 | + long numBytes; |
| 62 | + if (is == null) { |
| 63 | + httpStatus = HttpStatus.SC_NOT_FOUND; |
| 64 | + numBytes = 0; |
| 65 | + } else if (range == null) { |
| 66 | + httpStatus = HttpStatus.SC_OK; |
| 67 | + numBytes = is.available(); |
| 68 | + } else { |
| 69 | + httpStatus = HttpStatus.SC_PARTIAL_CONTENT; |
| 70 | + is.skipNBytes(range.start()); |
| 71 | + numBytes = range.end() - range.start() + 1; |
| 72 | + } |
| 73 | + logger.info("response: {} {}", exchange.getRequestURI().getPath(), httpStatus); |
| 74 | + exchange.sendResponseHeaders(httpStatus, numBytes); |
| 75 | + try (OutputStream os = exchange.getResponseBody()) { |
| 76 | + while (numBytes > 0) { |
| 77 | + byte[] bytes = is.readNBytes((int) Math.min(1 << 20, numBytes)); |
| 78 | + os.write(bytes); |
| 79 | + numBytes -= bytes.length; |
| 80 | + } |
| 81 | + } |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + private InputStream getInputStream(HttpExchange exchange) throws IOException { |
| 86 | + String path = exchange.getRequestURI().getPath().substring(1); // Strip leading slash |
| 87 | + String modelId = path.substring(0, path.indexOf('.')); |
| 88 | + String extension = path.substring(path.indexOf('.') + 1); |
| 89 | + |
| 90 | + // If a model specifically optimized for some platform is requested, |
| 91 | + // serve the default non-optimized model instead, which is compatible. |
| 92 | + String defaultModelId = modelId.replace("_linux-x86_64", ""); |
| 93 | + |
| 94 | + ClassLoader classloader = Thread.currentThread().getContextClassLoader(); |
| 95 | + InputStream is = classloader.getResourceAsStream(defaultModelId + "." + extension); |
| 96 | + if (is != null && modelId.equals(defaultModelId) == false && extension.equals("metadata.json")) { |
| 97 | + // When an optimized version is requested, fix the default metadata, |
| 98 | + // so that it contains the correct model ID. |
| 99 | + try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, is.readAllBytes())) { |
| 100 | + is.close(); |
| 101 | + ModelPackageConfig packageConfig = ModelPackageConfig.fromXContentLenient(parser); |
| 102 | + packageConfig = new ModelPackageConfig.Builder(packageConfig).setPackedModelId(modelId).build(); |
| 103 | + is = new ByteArrayInputStream(packageConfig.toString().getBytes(StandardCharsets.UTF_8)); |
| 104 | + } |
| 105 | + } |
| 106 | + return is; |
| 107 | + } |
| 108 | + |
| 109 | + @Override |
| 110 | + public Statement apply(Statement statement, Description description) { |
| 111 | + return new Statement() { |
| 112 | + @Override |
| 113 | + public void evaluate() throws Throwable { |
| 114 | + logger.info("Starting ML model server"); |
| 115 | + HttpServer server = HttpServer.create(); |
| 116 | + while (true) { |
| 117 | + port = new Random().nextInt(10000, 65536); |
| 118 | + try { |
| 119 | + server.bind(new InetSocketAddress(HOST, port), 1); |
| 120 | + } catch (Exception e) { |
| 121 | + continue; |
| 122 | + } |
| 123 | + break; |
| 124 | + } |
| 125 | + logger.info("Bound ML model server to port {}", port); |
| 126 | + |
| 127 | + ExecutorService executor = Executors.newCachedThreadPool(); |
| 128 | + server.setExecutor(executor); |
| 129 | + server.createContext("/", MlModelServer.this::handle); |
| 130 | + server.start(); |
| 131 | + |
| 132 | + try { |
| 133 | + statement.evaluate(); |
| 134 | + } finally { |
| 135 | + logger.info("Stopping ML model server on port {}", port); |
| 136 | + server.stop(1); |
| 137 | + executor.shutdown(); |
| 138 | + } |
| 139 | + } |
| 140 | + }; |
| 141 | + } |
| 142 | +} |
0 commit comments