Skip to content

Commit d179f77

Browse files
committed
add request configuration support
add stream request support
1 parent ea9c702 commit d179f77

File tree

4 files changed

+118
-2
lines changed

4 files changed

+118
-2
lines changed

README.md

+24-1
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ package com.codingapi.gemini.client;
2828

2929
import com.codingapi.gemini.pojo.Embedding;
3030
import com.codingapi.gemini.pojo.Generate;
31+
import lombok.SneakyThrows;
3132
import org.junit.jupiter.api.Test;
3233
import org.springframework.beans.factory.annotation.Autowired;
3334
import org.springframework.boot.test.context.SpringBootTest;
3435

3536
import java.io.File;
37+
import java.io.IOException;
3638
import java.util.List;
3739

3840
@SpringBootTest
@@ -49,8 +51,29 @@ class GeminiClientTest {
4951
System.out.println(answer);
5052
}
5153

54+
55+
@Test
56+
void generateConfiguration() {
57+
Generate.Request request = Generate.creatTextChart("你好,请用中文简体回答我,你如何看待区块链?");
58+
request.setGenerationConfig(new Generate.GenerationConfig(List.of("Title"), 1.0f, 1000, 0.8f, 10));
59+
request.addSafetySetting("HARM_CATEGORY_DANGEROUS_CONTENT", "BLOCK_ONLY_HIGH");
60+
Generate.Response response = client.generate(request);
61+
String answer = Generate.toAnswer(response);
62+
System.out.println(answer);
63+
}
64+
65+
@Test
66+
@SneakyThrows
67+
void stream() {
68+
Generate.Request request = Generate.creatTextChart("你好,请用中文简体回答我,你如何看待区块链?");
69+
client.stream(request, response -> {
70+
String answer = Generate.toAnswer(response);
71+
System.out.println(answer);
72+
});
73+
}
74+
5275
@Test
53-
void generateVision() {
76+
void generateVision() throws IOException {
5477
Generate.Request request = Generate.creatImageChart("这是一张什么图片?", new File("./images/test.png"));
5578
Generate.Response response = client.generate(request);
5679
String answer = Generate.toAnswer(response);

src/main/java/com/codingapi/gemini/client/GeminiClient.java

+27
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package com.codingapi.gemini.client;
22

3+
import com.alibaba.fastjson.JSONArray;
34
import com.alibaba.fastjson.JSONObject;
45
import com.codingapi.gemini.pojo.Embedding;
56
import com.codingapi.gemini.pojo.Generate;
67
import lombok.extern.slf4j.Slf4j;
8+
import org.springframework.core.io.Resource;
79
import org.springframework.http.HttpEntity;
810
import org.springframework.http.HttpHeaders;
911
import org.springframework.http.HttpMethod;
@@ -12,8 +14,13 @@
1214
import org.springframework.util.StringUtils;
1315
import org.springframework.web.client.RestTemplate;
1416

17+
import java.io.IOException;
18+
import java.io.InputStream;
1519
import java.net.InetSocketAddress;
1620
import java.net.Proxy;
21+
import java.util.List;
22+
import java.util.Objects;
23+
import java.util.function.Consumer;
1724

1825
@Slf4j
1926
public class GeminiClient {
@@ -40,6 +47,26 @@ public GeminiClient(String apiKey, String proxyHost, int proxyPort) {
4047
restTemplate.setRequestFactory(requestFactory);
4148
}
4249

50+
public void stream(Generate.Request request, Consumer<Generate.Response> consumer) throws IOException {
51+
String url = baseUrl + "models/gemini-pro:streamGenerateContent?key=" + apiKey;
52+
String json = request.toJSONString();
53+
log.info("json:{}", json);
54+
HttpEntity<String> httpEntity = new HttpEntity<>(json, headers);
55+
ResponseEntity<Resource> response = restTemplate.exchange(url, HttpMethod.POST, httpEntity, Resource.class);
56+
InputStream in = Objects.requireNonNull(response.getBody()).getInputStream();
57+
byte[] bytes = new byte[1024 * 8];
58+
int len;
59+
while ((len = in.read(bytes)) != -1) {
60+
String body = new String(bytes, 0, len);
61+
List<Generate.Response> responseList = JSONArray.parseArray(body, Generate.Response.class);
62+
for (Generate.Response res : responseList) {
63+
consumer.accept(res);
64+
}
65+
}
66+
in.close();
67+
}
68+
69+
4370
public Generate.Response generate(Generate.Request request) {
4471
String url;
4572
if (request.isVision()) {

src/main/java/com/codingapi/gemini/pojo/Generate.java

+45-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ public static Request creatImageChart(String text, File image) throws IOExceptio
3737
}
3838

3939

40-
4140
public static String toAnswer(Response response) {
4241
if (response == null || response.getCandidates() == null || response.getCandidates().isEmpty()) {
4342
return null;
@@ -46,6 +45,38 @@ public static String toAnswer(Response response) {
4645
}
4746

4847

48+
@Setter
49+
@Getter
50+
public static class SafetySetting {
51+
private String category;
52+
private String threshold;
53+
}
54+
55+
56+
@Setter
57+
@Getter
58+
public static class GenerationConfig {
59+
private List<String> stopSequences;
60+
private float temperature;
61+
private int maxOutputTokens;
62+
private float topP;
63+
private float topK;
64+
65+
66+
public GenerationConfig(List<String> stopSequences,
67+
float temperature,
68+
int maxOutputTokens,
69+
float topP,
70+
float topK) {
71+
this.stopSequences = stopSequences;
72+
this.temperature = temperature;
73+
this.maxOutputTokens = maxOutputTokens;
74+
this.topP = topP;
75+
this.topK = topK;
76+
}
77+
}
78+
79+
4980
@Setter
5081
@Getter
5182
public static class Request {
@@ -54,10 +85,23 @@ public static class Request {
5485
@JSONField(serialize = false)
5586
private boolean vision;
5687

88+
private List<SafetySetting> safetySettings;
89+
private GenerationConfig generationConfig;
90+
5791
public Request() {
5892
this.contents = new ArrayList<>();
5993
}
6094

95+
public void addSafetySetting(String category, String threshold) {
96+
if (safetySettings == null) {
97+
safetySettings = new ArrayList<>();
98+
}
99+
SafetySetting safetySetting = new SafetySetting();
100+
safetySetting.setCategory(category);
101+
safetySetting.setThreshold(threshold);
102+
safetySettings.add(safetySetting);
103+
}
104+
61105
public String toJSONString() {
62106
return JSONObject.toJSONString(this);
63107
}

src/test/java/com/codingapi/gemini/client/GeminiClientTest.java

+22
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.codingapi.gemini.pojo.Embedding;
44
import com.codingapi.gemini.pojo.Generate;
5+
import lombok.SneakyThrows;
56
import org.junit.jupiter.api.Test;
67
import org.springframework.beans.factory.annotation.Autowired;
78
import org.springframework.boot.test.context.SpringBootTest;
@@ -24,6 +25,27 @@ void generate() {
2425
System.out.println(answer);
2526
}
2627

28+
29+
@Test
30+
void generateConfiguration() {
31+
Generate.Request request = Generate.creatTextChart("你好,请用中文简体回答我,你如何看待区块链?");
32+
request.setGenerationConfig(new Generate.GenerationConfig(List.of("Title"), 1.0f, 1000, 0.8f, 10));
33+
request.addSafetySetting("HARM_CATEGORY_DANGEROUS_CONTENT", "BLOCK_ONLY_HIGH");
34+
Generate.Response response = client.generate(request);
35+
String answer = Generate.toAnswer(response);
36+
System.out.println(answer);
37+
}
38+
39+
@Test
40+
@SneakyThrows
41+
void stream() {
42+
Generate.Request request = Generate.creatTextChart("你好,请用中文简体回答我,你如何看待区块链?");
43+
client.stream(request, response -> {
44+
String answer = Generate.toAnswer(response);
45+
System.out.println(answer);
46+
});
47+
}
48+
2749
@Test
2850
void generateVision() throws IOException {
2951
Generate.Request request = Generate.creatImageChart("这是一张什么图片?", new File("./images/test.png"));

0 commit comments

Comments
 (0)