Skip to content

Commit 4c01807

Browse files
[8.17] Test ML model server (#120270) (#120588)
* Test ML model server (#120270) * Fix model downloading for very small models. * Test MlModelServer * Tiny ELSER * unmute TextEmbeddingCrudIT and DefaultEndPointsIT * update ELSER * Improve MlModelServer * tiny E5 * more logging * improved E5 model * tiny reranker * scan for ports * [CI] Auto commit changes from spotless * Serve default models when optimized model is requested * @ClassRule * polish code * Respect dynamic setting ML model repo * fix metadata for optimized models * improve logging --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co> * backport HttpHeaderParser * Fix stripping platform --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
1 parent 8059034 commit 4c01807

File tree

19 files changed

+353
-26
lines changed

19 files changed

+353
-26
lines changed

muted-tests.yml

-18
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,6 @@ tests:
199199
- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
200200
method: test {categorize.Categorize ASYNC}
201201
issue: https://github.com/elastic/elasticsearch/issues/116373
202-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
203-
method: testPutE5WithTrainedModelAndInference
204-
issue: https://github.com/elastic/elasticsearch/issues/114023
205-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
206-
method: testPutE5Small_withPlatformAgnosticVariant
207-
issue: https://github.com/elastic/elasticsearch/issues/113983
208202
- class: org.elasticsearch.datastreams.LazyRolloverDuringDisruptionIT
209203
method: testRolloverIsExecutedOnce
210204
issue: https://github.com/elastic/elasticsearch/issues/112634
@@ -214,9 +208,6 @@ tests:
214208
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityWithApmTracingRestIT
215209
method: testTracingCrossCluster
216210
issue: https://github.com/elastic/elasticsearch/issues/112731
217-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
218-
method: testPutE5Small_withPlatformSpecificVariant
219-
issue: https://github.com/elastic/elasticsearch/issues/113950
220211
- class: org.elasticsearch.smoketest.DocsClientYamlTestSuiteIT
221212
method: test {yaml=reference/rest-api/usage/line_38}
222213
issue: https://github.com/elastic/elasticsearch/issues/113694
@@ -226,9 +217,6 @@ tests:
226217
- class: org.elasticsearch.reservedstate.service.FileSettingsServiceTests
227218
method: testProcessFileChanges
228219
issue: https://github.com/elastic/elasticsearch/issues/115280
229-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
230-
method: testInferDeploysDefaultE5
231-
issue: https://github.com/elastic/elasticsearch/issues/115361
232220
- class: org.elasticsearch.xpack.inference.InferenceCrudIT
233221
method: testSupportedStream
234222
issue: https://github.com/elastic/elasticsearch/issues/113430
@@ -285,9 +273,6 @@ tests:
285273
- class: org.elasticsearch.xpack.esql.qa.mixed.EsqlClientYamlIT
286274
method: test {p0=esql/61_enrich_ip/IP strings}
287275
issue: https://github.com/elastic/elasticsearch/issues/116529
288-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
289-
method: testInferDeploysDefaultElser
290-
issue: https://github.com/elastic/elasticsearch/issues/114913
291276
- class: org.elasticsearch.threadpool.SimpleThreadPoolIT
292277
method: testThreadPoolMetrics
293278
issue: https://github.com/elastic/elasticsearch/issues/108320
@@ -336,9 +321,6 @@ tests:
336321
- class: org.elasticsearch.xpack.searchablesnapshots.RetrySearchIntegTests
337322
method: testRetryPointInTime
338323
issue: https://github.com/elastic/elasticsearch/issues/117116
339-
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
340-
method: testMultipleInferencesTriggeringDownloadAndDeploy
341-
issue: https://github.com/elastic/elasticsearch/issues/117208
342324
- class: org.elasticsearch.xpack.spatial.search.GeoGridAggAndQueryConsistencyIT
343325
method: testGeoPointGeoTile
344326
issue: https://github.com/elastic/elasticsearch/issues/115818
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.test.fixture;
11+
12+
import java.util.regex.Matcher;
13+
import java.util.regex.Pattern;
14+
15+
public enum HttpHeaderParser {
    ;

    /** Matches exactly one bounded byte range, e.g. {@code bytes=0-499}; both offsets are captured as digit groups. */
    private static final Pattern RANGE_HEADER_PATTERN = Pattern.compile("bytes=([0-9]+)-([0-9]+)");

    /**
     * Parse a "Range" header
     *
     * Note: only a single bounded range is supported (e.g. <code>Range: bytes={range_start}-{range_end}</code>).
     * Multiple ranges, open-ended ranges and suffix-length ranges do not match and yield {@code null}.
     * A range whose end offset is smaller than its start offset is not a valid byte range and is also
     * reported as malformed.
     *
     * @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range">MDN: Range header</a>
     * @param rangeHeaderValue The header value as a string, must not be null
     * @return a {@link Range} instance representing the parsed value, or null if the header is malformed
     * @throws NullPointerException if {@code rangeHeaderValue} is null
     */
    public static Range parseRangeHeader(String rangeHeaderValue) {
        final Matcher matcher = RANGE_HEADER_PATTERN.matcher(rangeHeaderValue);
        if (matcher.matches() == false) {
            return null;
        }
        long start;
        long end;
        try {
            start = Long.parseLong(matcher.group(1));
            end = Long.parseLong(matcher.group(2));
        } catch (NumberFormatException e) {
            // The pattern only guarantees digits; values that overflow a long are treated as malformed.
            return null;
        }
        if (end < start) {
            // An inverted range is invalid; rejecting it here keeps callers from computing
            // a non-positive length from (end - start + 1).
            return null;
        }
        return new Range(start, end);
    }

    /**
     * A parsed byte range.
     *
     * @param start offset of the first requested byte
     * @param end offset of the last requested byte (inclusive)
     */
    public record Range(long start, long end) {}
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.http;
11+
12+
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.test.ESTestCase;
14+
import org.elasticsearch.test.fixture.HttpHeaderParser;
15+
16+
import java.math.BigInteger;
17+
18+
public class HttpHeaderParserTests extends ESTestCase {
19+
20+
public void testParseRangeHeader() {
21+
final long start = randomLongBetween(0, 10_000);
22+
final long end = randomLongBetween(start, start + 10_000);
23+
assertEquals(new HttpHeaderParser.Range(start, end), HttpHeaderParser.parseRangeHeader("bytes=" + start + "-" + end));
24+
}
25+
26+
public void testParseRangeHeaderInvalidLong() {
27+
final BigInteger longOverflow = BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE).add(randomBigInteger());
28+
assertNull(HttpHeaderParser.parseRangeHeader("bytes=123-" + longOverflow));
29+
assertNull(HttpHeaderParser.parseRangeHeader("bytes=" + longOverflow + "-123"));
30+
}
31+
32+
public void testParseRangeHeaderMultipleRangesNotMatched() {
33+
assertNull(
34+
HttpHeaderParser.parseRangeHeader(
35+
Strings.format(
36+
"bytes=%d-%d,%d-%d",
37+
randomIntBetween(0, 99),
38+
randomIntBetween(100, 199),
39+
randomIntBetween(200, 299),
40+
randomIntBetween(300, 399)
41+
)
42+
)
43+
);
44+
}
45+
46+
public void testParseRangeHeaderEndlessRangeNotMatched() {
47+
assertNull(HttpHeaderParser.parseRangeHeader(Strings.format("bytes=%d-", randomLongBetween(0, Long.MAX_VALUE))));
48+
}
49+
50+
public void testParseRangeHeaderSuffixLengthNotMatched() {
51+
assertNull(HttpHeaderParser.parseRangeHeader(Strings.format("bytes=-%d", randomLongBetween(0, Long.MAX_VALUE))));
52+
}
53+
}

x-pack/plugin/inference/qa/inference-service-tests/build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
apply plugin: 'elasticsearch.internal-java-rest-test'
22

33
dependencies {
4+
javaRestTestImplementation project(path: xpackModule('core'))
45
javaRestTestImplementation project(path: xpackModule('inference'))
56
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
67
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

+18
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
2323
import org.elasticsearch.test.rest.ESRestTestCase;
2424
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
25+
import org.junit.Before;
2526
import org.junit.ClassRule;
2627

2728
import java.io.IOException;
@@ -37,6 +38,7 @@
3738
import static org.hamcrest.Matchers.hasSize;
3839

3940
public class InferenceBaseRestTest extends ESRestTestCase {
41+
4042
@ClassRule
4143
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
4244
.distribution(DistributionType.DEFAULT)
@@ -46,6 +48,22 @@ public class InferenceBaseRestTest extends ESRestTestCase {
4648
.user("x_pack_rest_user", "x-pack-test-password")
4749
.build();
4850

51+
/**
 * Class-level rule providing the {@link MlModelServer} whose URL
 * {@code setMlModelRepository()} registers as the cluster's {@code xpack.ml.model_repository}.
 */
@ClassRule
52+
public static MlModelServer mlModelServer = new MlModelServer();
53+
54+
/**
 * Runs before each test: sends {@code PUT /_cluster/settings} with the persistent setting
 * {@code xpack.ml.model_repository} set to the URL of {@code mlModelServer}, and asserts
 * that the cluster answered OK.
 *
 * @throws IOException if performing the settings request fails
 */
@Before
55+
public void setMlModelRepository() throws IOException {
56+
logger.info("setting ML model repository to: {}", mlModelServer.getUrl());
57+
var request = new Request("PUT", "/_cluster/settings");
58+
request.setJsonEntity(Strings.format("""
59+
{
60+
"persistent": {
61+
"xpack.ml.model_repository": "%s"
62+
}
63+
}""", mlModelServer.getUrl()));
64+
assertOK(client().performRequest(request));
65+
}
66+
4967
@Override
5068
protected String getTestRestCluster() {
5169
return cluster.getHttpAddresses();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import com.sun.net.httpserver.HttpExchange;
11+
import com.sun.net.httpserver.HttpServer;
12+
13+
import org.apache.http.HttpHeaders;
14+
import org.apache.http.HttpStatus;
15+
import org.apache.http.client.utils.URIBuilder;
16+
import org.elasticsearch.logging.LogManager;
17+
import org.elasticsearch.logging.Logger;
18+
import org.elasticsearch.test.fixture.HttpHeaderParser;
19+
import org.elasticsearch.xcontent.XContentParser;
20+
import org.elasticsearch.xcontent.XContentParserConfiguration;
21+
import org.elasticsearch.xcontent.XContentType;
22+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
23+
import org.junit.rules.TestRule;
24+
import org.junit.runner.Description;
25+
import org.junit.runners.model.Statement;
26+
27+
import java.io.ByteArrayInputStream;
28+
import java.io.IOException;
29+
import java.io.InputStream;
30+
import java.io.OutputStream;
31+
import java.net.InetSocketAddress;
32+
import java.nio.charset.StandardCharsets;
33+
import java.util.Random;
34+
import java.util.concurrent.ExecutorService;
35+
import java.util.concurrent.Executors;
36+
37+
/**
38+
* Simple model server to serve ML models.
39+
* The URL path corresponds to a file name in this class's resources.
40+
* If the file is found, its content is returned, otherwise 404.
41+
* Respects a range header to serve partial content.
42+
*/
43+
public class MlModelServer implements TestRule {
44+
45+
private static final String HOST = "localhost";
46+
private static final Logger logger = LogManager.getLogger(MlModelServer.class);
47+
48+
private int port;
49+
50+
public String getUrl() {
51+
return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString();
52+
}
53+
54+
private void handle(HttpExchange exchange) throws IOException {
55+
String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE);
56+
HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null;
57+
logger.info("request: {} range={}", exchange.getRequestURI().getPath(), range);
58+
59+
try (InputStream is = getInputStream(exchange)) {
60+
int httpStatus;
61+
long numBytes;
62+
if (is == null) {
63+
httpStatus = HttpStatus.SC_NOT_FOUND;
64+
numBytes = 0;
65+
} else if (range == null) {
66+
httpStatus = HttpStatus.SC_OK;
67+
numBytes = is.available();
68+
} else {
69+
httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
70+
is.skipNBytes(range.start());
71+
numBytes = range.end() - range.start() + 1;
72+
}
73+
logger.info("response: {} {}", exchange.getRequestURI().getPath(), httpStatus);
74+
exchange.sendResponseHeaders(httpStatus, numBytes);
75+
try (OutputStream os = exchange.getResponseBody()) {
76+
while (numBytes > 0) {
77+
byte[] bytes = is.readNBytes((int) Math.min(1 << 20, numBytes));
78+
os.write(bytes);
79+
numBytes -= bytes.length;
80+
}
81+
}
82+
}
83+
}
84+
85+
private InputStream getInputStream(HttpExchange exchange) throws IOException {
86+
String path = exchange.getRequestURI().getPath().substring(1); // Strip leading slash
87+
String modelId = path.substring(0, path.indexOf('.'));
88+
String extension = path.substring(path.indexOf('.') + 1);
89+
90+
// If a model specifically optimized for some platform is requested,
91+
// serve the default non-optimized model instead, which is compatible.
92+
String defaultModelId = modelId.replace("_linux-x86_64", "");
93+
94+
ClassLoader classloader = Thread.currentThread().getContextClassLoader();
95+
InputStream is = classloader.getResourceAsStream(defaultModelId + "." + extension);
96+
if (is != null && modelId.equals(defaultModelId) == false && extension.equals("metadata.json")) {
97+
// When an optimized version is requested, fix the default metadata,
98+
// so that it contains the correct model ID.
99+
try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, is.readAllBytes())) {
100+
is.close();
101+
ModelPackageConfig packageConfig = ModelPackageConfig.fromXContentLenient(parser);
102+
packageConfig = new ModelPackageConfig.Builder(packageConfig).setPackedModelId(modelId).build();
103+
is = new ByteArrayInputStream(packageConfig.toString().getBytes(StandardCharsets.UTF_8));
104+
}
105+
}
106+
return is;
107+
}
108+
109+
@Override
110+
public Statement apply(Statement statement, Description description) {
111+
return new Statement() {
112+
@Override
113+
public void evaluate() throws Throwable {
114+
logger.info("Starting ML model server");
115+
HttpServer server = HttpServer.create();
116+
while (true) {
117+
port = new Random().nextInt(10000, 65536);
118+
try {
119+
server.bind(new InetSocketAddress(HOST, port), 1);
120+
} catch (Exception e) {
121+
continue;
122+
}
123+
break;
124+
}
125+
logger.info("Bound ML model server to port {}", port);
126+
127+
ExecutorService executor = Executors.newCachedThreadPool();
128+
server.setExecutor(executor);
129+
server.createContext("/", MlModelServer.this::handle);
130+
server.start();
131+
132+
try {
133+
statement.evaluate();
134+
} finally {
135+
logger.info("Stopping ML model server on port {}", port);
136+
server.stop(1);
137+
executor.shutdown();
138+
}
139+
}
140+
};
141+
}
142+
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java

-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
import static org.hamcrest.Matchers.containsString;
2020

21-
// This test was previously disabled in CI due to the models being too large
22-
// See "https://github.com/elastic/elasticsearch/issues/105198".
2321
public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
2422

2523
public void testPutE5Small_withNoModelVariant() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"packaged_model_id": "elser_model_2",
3+
"minimum_version": "11.0.0",
4+
"size": 1859242,
5+
"sha256": "602dbccfb2746e5700bf65d8019b06fb2ec1e3c5bfb980eb2005fc17c1bfe0c0",
6+
"description": "Elastic Learned Sparse EncodeR v2",
7+
"model_type": "pytorch",
8+
"tags": [
9+
"elastic"
10+
],
11+
"inference_config": {
12+
"text_expansion": {
13+
"tokenization": {
14+
"bert": {
15+
"do_lower_case": true,
16+
"with_special_tokens": true,
17+
"max_sequence_length": 512,
18+
"truncate": "first",
19+
"span": -1
20+
}
21+
}
22+
}
23+
},
24+
"vocabulary_file": "elser_model_2.vocab.json"
25+
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/resources/elser_model_2.vocab.json

+1
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)