
@Test
@Disabled("需要本地运行 Ollama 服务")
public void testOllamaEmbedding() {
// Ollama API 地址
String apiUrl = "http://localhost:11434/api/embeddings";
String apiKey = ""; // Ollama 本地不需要 key
String model = "nomic-embed-text"; // 或 mxbai-embed-large
EmbeddingClient client = new EmbeddingClientImpl(apiUrl, apiKey);
// 水果库
List fruits = Arrays.asList(new Fruit("红富士苹果", "红色 甜 脆 苹果 新鲜"), new Fruit("青苹果", "绿色 酸 脆 苹果 清爽"),
new Fruit("金帅苹果", "黄色 甜 软 苹果"), new Fruit("香蕉", "黄色 甜 软 香蕉 热带水果"), new Fruit("草莓", "红色 甜 小 草莓 多汁 浆果"),
new Fruit("西瓜", "绿色外皮 红色果肉 甜 大 西瓜 多汁 夏天"), new Fruit("葡萄", "紫色 甜 小 葡萄 多汁 成串"));
// 为每个水果生成嵌入向量
for (Fruit fruit : fruits) {
fruit.embedding = client.getEmbeddingVector(model, fruit.description);
}
// 用户搜索
String query = "红色的甜水果";
double[] queryVector = client.getEmbeddingVector(model, query);
System.out.println("搜索: "" + query + """);
System.out.println("向量维度: " + queryVector.length);
System.out.println();
// 按相似度排序
fruits.sort(Comparator.comparingDouble(f -> -cosineSimilarity(queryVector, f.embedding)));
// 输出结果
System.out.println("搜索结果(按相似度排序):");
for (Fruit f : fruits) {
double sim = cosineSimilarity(queryVector, f.embedding);
System.out.printf(" %s (%.4f): %s%n", f.name, sim, f.description);
}
}
/**
* 计算两个向量的余弦相似度
*/
public static double cosineSimilarity(double[] vectorA, double[] vectorB) {
if (vectorA.length != vectorB.length) {
throw new IllegalArgumentException("向量维度必须相同");
}
double dotProduct = 0;
double normA = 0;
double normB = 0;
for (int i = 0; i
@Slf4j
public class EmbeddingClientImpl implements EmbeddingClient {
private final RestTemplate restTemplate;
private final String address;
private final String key;
/**
 * Creates a client that posts embedding requests to {@code address}.
 *
 * <p>HTTP goes through a pooled Apache HttpClient (100 connections total, 20 per route).
 * Only a UTF-8 {@link StringHttpMessageConverter} is registered, so request and response
 * bodies travel as plain strings and Jackson is never needed.
 *
 * <p>NOTE(review): the pooled HttpClient built here is never closed by any code visible in
 * this class; confirm whether instances are long-lived or whether a close hook is needed.
 *
 * @param address full URL of the embeddings endpoint
 * @param key     bearer token; {@code null} or empty means no Authorization header is sent
 */
public EmbeddingClientImpl(String address, String key) {
    this.address = address;
    this.key = key;
    this.restTemplate = new RestTemplate(buildRequestFactory());
    // Replace the default converter list with a single UTF-8 string converter
    this.restTemplate.setMessageConverters(
            Collections.singletonList(new StringHttpMessageConverter(StandardCharsets.UTF_8)));
}

/**
 * Builds the pooled, HttpClient-backed request factory used by the RestTemplate.
 */
private static HttpComponentsClientHttpRequestFactory buildRequestFactory() {
    PoolingHttpClientConnectionManager pool = new PoolingHttpClientConnectionManager();
    pool.setMaxTotal(100);
    pool.setDefaultMaxPerRoute(20);

    // 30 s to lease a pooled connection, 5 min to wait for the response
    RequestConfig defaults = RequestConfig.custom()
            .setConnectionRequestTimeout(Timeout.ofSeconds(30))
            .setResponseTimeout(Timeout.ofSeconds(300))
            .build();

    HttpClient pooledClient = HttpClientBuilder.create()
            .setConnectionManager(pool)
            .setDefaultRequestConfig(defaults)
            .build();

    HttpComponentsClientHttpRequestFactory factory = new HttpComponentsClientHttpRequestFactory(pooledClient);
    factory.setConnectTimeout(30000); // 30 s connect timeout
    // Same 30 s value as the RequestConfig connection-request timeout above
    factory.setConnectionRequestTimeout(30000);
    return factory;
}
/**
 * Posts one embedding request and returns the raw JSON response body.
 *
 * <p>The request body is {@code {"input": input, "model": model}}. A
 * {@code Authorization: Bearer <key>} header is added only when a non-empty key was configured.
 *
 * @param model model name
 * @param input text to embed
 * @return the raw response body
 * @throws RuntimeException if the response status is anything other than 200 OK
 */
@Override
public String embedding(String model, String input) {
    long start = System.currentTimeMillis();

    HttpHeaders headers = new HttpHeaders();
    headers.setContentType(MediaType.APPLICATION_JSON);
    headers.setAcceptCharset(Collections.singletonList(StandardCharsets.UTF_8));
    if (key != null && !key.isEmpty()) {
        headers.add("Authorization", "Bearer " + key);
    }

    // Serialize the request to a JSON string; the RestTemplate only carries String bodies
    JSONObject jsonObject = new JSONObject();
    jsonObject.put("input", input);
    jsonObject.put("model", model);
    String body = jsonObject.toString();
    log.debug("Embedding Request Body: {}", body);

    // Typed entities: with raw types getBody() is Object and cannot be returned as String
    HttpEntity<String> req = new HttpEntity<>(body, headers);
    ResponseEntity<String> result = restTemplate.postForEntity(address, req, String.class);
    if (!result.getStatusCode().equals(HttpStatus.OK)) {
        throw new RuntimeException("embeddings error, request: " + body + ", response: " + result.getBody());
    }
    log.info("embedding cost {} ms", System.currentTimeMillis() - start);
    return result.getBody();
}
/**
 * Returns the embedding vector for {@code input}.
 *
 * <p>Sends the request through {@link #embedding(String, String)} and decodes the raw JSON
 * body with {@code parseEmbeddingVector}. That parser recognises the OpenAI, Ollama and
 * Alibaba Tongyi response shapes. OpenAI-style example:
 * <pre>
 * {
 *   "object": "list",
 *   "data": [{
 *     "object": "embedding",
 *     "index": 0,
 *     "embedding": [0.0023064255, -0.009327292, ...]
 *   }],
 *   "model": "text-embedding-ada-002",
 *   "usage": {"prompt_tokens": 8, "total_tokens": 8}
 * }
 * </pre>
 *
 * @param model model name
 * @param input text to embed
 * @return the embedding vector
 */
@Override
public double[] getEmbeddingVector(String model, String input) {
    return parseEmbeddingVector(embedding(model, input));
}
/**
 * Parses an embedding-service JSON response into a vector.
 *
 * <p>Supported response shapes, tried in this order:
 * <ul>
 *   <li>OpenAI-compatible: {@code data[0].embedding}</li>
 *   <li>Ollama legacy {@code /api/embeddings}: top-level {@code embedding}</li>
 *   <li>Ollama {@code /api/embed}: {@code embeddings[0]}</li>
 *   <li>Alibaba Tongyi: {@code output.embeddings[0].embedding}</li>
 * </ul>
 *
 * @param response JSON response body
 * @return the embedding vector; never empty
 * @throws RuntimeException if the body cannot be parsed, matches no supported shape, or
 *         carries an empty vector. Every failure is logged and wrapped.
 */
private double[] parseEmbeddingVector(String response) {
    try {
        JSONObject jsonResponse = JSONObject.parseObject(response);
        JSONArray embeddingArray = jsonResponse == null ? null : findEmbeddingArray(jsonResponse);
        if (embeddingArray == null) {
            throw new RuntimeException("无法解析嵌入向量响应: " + response);
        }
        // An empty vector (e.g. Ollama's legacy endpoint answering a request without "prompt")
        // would otherwise make every downstream cosine similarity NaN.
        if (embeddingArray.isEmpty()) {
            throw new RuntimeException("嵌入向量为空: " + response);
        }
        return jsonArrayToDoubleArray(embeddingArray);
    }
    catch (Exception e) {
        log.error("解析嵌入向量失败: {}", response, e);
        throw new RuntimeException("解析嵌入向量失败", e);
    }
}

/**
 * Locates the embedding array inside a parsed response.
 *
 * @param json parsed response object
 * @return the embedding array, or {@code null} when no supported shape matches
 */
private static JSONArray findEmbeddingArray(JSONObject json) {
    // OpenAI-compatible: {"data": [{"embedding": [...]}]}
    JSONArray data = json.getJSONArray("data");
    if (data != null && !data.isEmpty()) {
        JSONObject first = data.getJSONObject(0);
        return first == null ? null : first.getJSONArray("embedding");
    }
    // Ollama legacy /api/embeddings: {"embedding": [...]}
    if (json.containsKey("embedding")) {
        return json.getJSONArray("embedding");
    }
    // Ollama /api/embed: {"embeddings": [[...]]}
    JSONArray embeddings = json.getJSONArray("embeddings");
    if (embeddings != null && !embeddings.isEmpty()) {
        return embeddings.getJSONArray(0);
    }
    // Alibaba Tongyi: {"output": {"embeddings": [{"embedding": [...]}]}}
    JSONObject output = json.getJSONObject("output");
    if (output != null) {
        JSONArray outputEmbeddings = output.getJSONArray("embeddings");
        if (outputEmbeddings != null && !outputEmbeddings.isEmpty()) {
            JSONObject first = outputEmbeddings.getJSONObject(0);
            return first == null ? null : first.getJSONArray("embedding");
        }
    }
    return null;
}
/**
* 将 JSONArray 转换为 double 数组
*/
private double[] jsonArrayToDoubleArray(JSONArray jsonArray) {
double[] result = new double[jsonArray.size()];
for (int i = 0; i

登录查看全部
参与评论
手机查看
返回顶部