Skip to content

Commit b46fc7b

Browse files
authored
feat(aiplatform): add tune model sample for Vertex LLMs (#8259)
1 parent b2d954c commit b46fc7b

File tree

2 files changed

+263
-0
lines changed

2 files changed

+263
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright 2023 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START aiplatform_sdk_tuning]
20+
import com.google.cloud.aiplatform.v1beta1.CreatePipelineJobRequest;
21+
import com.google.cloud.aiplatform.v1beta1.LocationName;
22+
import com.google.cloud.aiplatform.v1beta1.PipelineJob;
23+
import com.google.cloud.aiplatform.v1beta1.PipelineJob.RuntimeConfig;
24+
import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
25+
import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
26+
import com.google.protobuf.Value;
27+
import java.io.IOException;
28+
import java.util.HashMap;
29+
import java.util.Map;
30+
31+
public class CreatePipelineJobModelTuningSample {
32+
33+
public static void main(String[] args) throws IOException {
34+
// TODO(developer): Replace these variables before running the sample.
35+
String project = "PROJECT";
36+
String location =
37+
"europe-west4"; // Model tuning is only supported in europe-west4 for Public Preview
38+
String pipelineJobDisplayName = "PIPELINE_JOB_DISPLAY_NAME";
39+
String modelDisplayName = "MODEL_DISPLAY_NAME";
40+
String outputDir = "OUTPUT_DIR";
41+
String datasetUri = "DATASET_URI";
42+
int trainingSteps = 100;
43+
44+
createPipelineJobModelTuningSample(
45+
project,
46+
location,
47+
pipelineJobDisplayName,
48+
modelDisplayName,
49+
outputDir,
50+
datasetUri,
51+
trainingSteps);
52+
}
53+
54+
// Create a model tuning job
55+
public static void createPipelineJobModelTuningSample(
56+
String project,
57+
String location,
58+
String pipelineJobDisplayName,
59+
String modelDisplayName,
60+
String outputDir,
61+
String datasetUri,
62+
int trainingSteps)
63+
throws IOException {
64+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
65+
PipelineServiceSettings pipelineServiceSettings =
66+
PipelineServiceSettings.newBuilder().setEndpoint(endpoint).build();
67+
68+
// Initialize client that will be used to send requests. This client only needs to be created
69+
// once, and can be reused for multiple requests.
70+
try (PipelineServiceClient client = PipelineServiceClient.create(pipelineServiceSettings)) {
71+
Map<String, Value> parameterValues = new HashMap<>();
72+
parameterValues.put("project", stringToValue(project));
73+
parameterValues.put("model_display_name", stringToValue(modelDisplayName));
74+
parameterValues.put("dataset_uri", stringToValue(datasetUri));
75+
parameterValues.put(
76+
"location",
77+
stringToValue(
78+
"us-central1")); // Deployment is only supported in us-central1 for Public Preview
79+
parameterValues.put("large_model_reference", stringToValue("text-bison@001"));
80+
parameterValues.put("train_steps", numberToValue(trainingSteps));
81+
82+
RuntimeConfig runtimeConfig =
83+
RuntimeConfig.newBuilder()
84+
.setGcsOutputDirectory(outputDir)
85+
.putAllParameterValues(parameterValues)
86+
.build();
87+
88+
PipelineJob pipelineJob =
89+
PipelineJob.newBuilder()
90+
.setTemplateUri(
91+
"https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v1.0.0")
92+
.setDisplayName(pipelineJobDisplayName)
93+
.setRuntimeConfig(runtimeConfig)
94+
.build();
95+
96+
LocationName parent = LocationName.of(project, location);
97+
CreatePipelineJobRequest request =
98+
CreatePipelineJobRequest.newBuilder()
99+
.setParent(parent.toString())
100+
.setPipelineJob(pipelineJob)
101+
.build();
102+
103+
PipelineJob response = client.createPipelineJob(request);
104+
System.out.format("response: %s\n", response);
105+
System.out.format("Name: %s\n", response.getName());
106+
}
107+
}
108+
109+
static Value stringToValue(String str) {
110+
return Value.newBuilder().setStringValue(str).build();
111+
}
112+
113+
static Value numberToValue(int n) {
114+
return Value.newBuilder().setNumberValue(n).build();
115+
}
116+
}
117+
118+
// [END aiplatform_sdk_tuning]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright 2023 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static junit.framework.TestCase.assertNotNull;
21+
22+
import com.google.api.gax.longrunning.OperationFuture;
23+
import com.google.cloud.aiplatform.v1beta1.DeleteOperationMetadata;
24+
import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
25+
import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
26+
import com.google.cloud.testing.junit4.MultipleAttemptsRule;
27+
import com.google.protobuf.Empty;
28+
import io.grpc.StatusRuntimeException;
29+
import java.io.ByteArrayOutputStream;
30+
import java.io.IOException;
31+
import java.io.PrintStream;
32+
import java.util.UUID;
33+
import java.util.concurrent.ExecutionException;
34+
import java.util.concurrent.TimeUnit;
35+
import java.util.concurrent.TimeoutException;
36+
import org.junit.After;
37+
import org.junit.Before;
38+
import org.junit.BeforeClass;
39+
import org.junit.Rule;
40+
import org.junit.Test;
41+
import org.junit.runner.RunWith;
42+
import org.junit.runners.JUnit4;
43+
44+
@RunWith(JUnit4.class)
45+
public class CreatePipelineJobModelTuningSampleTest {
46+
@Rule public final MultipleAttemptsRule multipleAttemptsRule = new MultipleAttemptsRule(3);
47+
48+
private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
49+
private static final String LOCATION = "europe-west4";
50+
private static final String OUTPUT_DIR =
51+
"gs://ucaip-samples-europe-west4/training_pipeline_output";
52+
private static final String DATASET_URI =
53+
"gs://cloud-samples-data/ai-platform/generative_ai/headline_classification.jsonl";
54+
private static final int TRAINING_STEPS = 100;
55+
private String pipelineJobName;
56+
private ByteArrayOutputStream bout;
57+
private PrintStream originalPrintStream;
58+
59+
private static void requireEnvVar(String varName) {
60+
String errorMessage =
61+
String.format("Environment variable '%s' is required to perform these tests.", varName);
62+
assertNotNull(errorMessage, System.getenv(varName));
63+
}
64+
65+
@BeforeClass
66+
public static void checkRequirements() {
67+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
68+
requireEnvVar("UCAIP_PROJECT_ID");
69+
}
70+
71+
@Before
72+
public void setUp() {
73+
bout = new ByteArrayOutputStream();
74+
PrintStream out = new PrintStream(bout);
75+
originalPrintStream = System.out;
76+
System.setOut(out);
77+
}
78+
79+
@After
80+
public void tearDown()
81+
throws IOException, InterruptedException, TimeoutException, ExecutionException {
82+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", LOCATION);
83+
PipelineServiceSettings pipelineServiceSettings =
84+
PipelineServiceSettings.newBuilder().setEndpoint(endpoint).build();
85+
86+
try (PipelineServiceClient pipelineServiceClient =
87+
PipelineServiceClient.create(pipelineServiceSettings)) {
88+
// Cancel the PipelineJob
89+
pipelineServiceClient.cancelPipelineJob(pipelineJobName);
90+
TimeUnit.MINUTES.sleep(2);
91+
92+
// Delete the PipelineJob
93+
int retryCount = 3;
94+
while (retryCount > 0) {
95+
retryCount--;
96+
try {
97+
OperationFuture<Empty, DeleteOperationMetadata> operationFuture =
98+
pipelineServiceClient.deletePipelineJobAsync(pipelineJobName);
99+
operationFuture.get(300, TimeUnit.SECONDS);
100+
101+
// if delete operation is successful, break out of the loop and continue
102+
break;
103+
} catch (StatusRuntimeException e) {
104+
// wait for another 1 minute, then retry
105+
System.out.println("Retrying (due to unfinished cancellation operation)...");
106+
TimeUnit.MINUTES.sleep(1);
107+
} catch (Exception otherExceptions) {
108+
// other exception, let them throw
109+
throw otherExceptions;
110+
}
111+
}
112+
}
113+
114+
System.out.flush();
115+
System.setOut(originalPrintStream);
116+
}
117+
118+
@Test
119+
public void createTrainingPipelineModelTuningSample() throws IOException {
120+
final String pipelineJobDisplayName =
121+
String.format(
122+
"temp_create_pipeline_job_test_%s",
123+
UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
124+
125+
final String modelDisplayName =
126+
String.format(
127+
"temp_create_pipeline_job_model_test_%s",
128+
UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26));
129+
130+
// Act
131+
CreatePipelineJobModelTuningSample.createPipelineJobModelTuningSample(
132+
PROJECT,
133+
LOCATION,
134+
pipelineJobDisplayName,
135+
modelDisplayName,
136+
OUTPUT_DIR,
137+
DATASET_URI,
138+
TRAINING_STEPS);
139+
140+
// Assert
141+
String got = bout.toString();
142+
assertThat(got).contains(pipelineJobDisplayName);
143+
pipelineJobName = got.split("Name: ")[1].split("\n")[0];
144+
}
145+
}

0 commit comments

Comments
 (0)