第一章:chat 初体验
- 作者:影子, Spring AI Alibaba Committer
- 本文档基于 Spring AI 1.0.0 版本,Spring AI Alibaba 1.0.0.2 版本
- 本章包含:chat快速上手 + 源码解读(ChatClient + ChatModel 自动注入、ChatClient 调用链路)
chat快速上手
通过自然语言与 AI 模型进行对话。下面实现了 chat 的几个典型案例:同步调用(Call)、流式调用(Stream)以及 ChatOptions 设置。完整实战代码见 https://github.com/GTyingzi/spring-ai-tutorial 仓库下的 chat 模块。
pom 文件
<dependencies> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-web</artifactId> </dependency>
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-openai</artifactId> </dependency>
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId> </dependency>
</dependencies>
application.yml
server:
  port: 8080

spring:
  application:
    name: advisor-base
  ai:
    openai:
      api-key: ${DASHSCOPE_API_KEY}
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      chat:
        options:
          model: qwen-max
由于访问限制,国内难以获取 OpenAI 的 api-key。阿里云百炼(DashScope)提供 OpenAI 兼容接口,可作为平替:只需替换对应的 api-key(示例中通过环境变量 DASHSCOPE_API_KEY 注入)和 base-url,并选择百炼支持的模型(如 qwen-max)即可。
controller
ChatController
package com.spring.ai.tutorial.chat.controller;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
/**
 * Demonstrates the two basic ways of talking to a chat model through {@link ChatClient}:
 * a synchronous call that returns the complete answer, and a streaming call that emits the
 * answer incrementally as Server-Sent Events.
 */
@RestController
@RequestMapping("/chat")
public class ChatController {

    /** Prompt used when the caller does not supply a {@code query} request parameter. */
    private static final String DEFAULT_QUERY = "你好,很高兴认识你,能简单介绍一下自己吗?";

    private final ChatClient chatClient;

    /**
     * @param builder the auto-configured {@link ChatClient.Builder}
     */
    public ChatController(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    /**
     * Sends {@code query} to the model and returns the complete reply.
     *
     * @param query the user prompt
     * @return the reply text (may be {@code null} if the model returned no content)
     */
    @GetMapping("/call")
    public String call(@RequestParam(value = "query", defaultValue = DEFAULT_QUERY) String query) {
        return chatClient.prompt(query).call().content();
    }

    /**
     * Sends {@code query} to the model and streams the reply back chunk by chunk.
     *
     * <p>{@code text/event-stream} makes Spring MVC write each emitted element as an SSE event.
     * This requires {@code org.springframework.http.MediaType} to be imported.
     *
     * @param query the user prompt
     * @return a stream of reply text fragments
     */
    @GetMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> stream(@RequestParam(value = "query", defaultValue = DEFAULT_QUERY) String query) {
        return chatClient.prompt(query).stream().content();
    }
}
效果
call 调用
stream 调用
ChatOptionController
package com.spring.ai.tutorial.chat.controller;
import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.openai.OpenAiChatOptions;import org.springframework.web.bind.annotation.GetMapping;import org.springframework.web.bind.annotation.RequestMapping;import org.springframework.web.bind.annotation.RequestParam;import org.springframework.web.bind.annotation.RestController;
/**
 * Shows how {@link OpenAiChatOptions} configured on the builder act as client-wide defaults,
 * and how options passed on an individual request take precedence over them.
 *
 * @author yingzi
 * @date 2025/5/24 16:52
 */
@RestController
@RequestMapping("/chat/option")
public class ChatOptionController {

    /** Prompt used by both endpoints when no {@code query} parameter is given. */
    private static final String DEFAULT_POEM_QUERY = "你好,请为我创造一首以“影子”为主题的诗";

    private final ChatClient chatClient;

    /**
     * Builds a client whose default temperature is 0.9.
     *
     * @param builder the auto-configured {@link ChatClient.Builder}
     */
    public ChatOptionController(ChatClient.Builder builder) {
        OpenAiChatOptions clientDefaults = OpenAiChatOptions.builder()
                .temperature(0.9)
                .build();
        this.chatClient = builder.defaultOptions(clientDefaults).build();
    }

    /** Uses the client-wide default options (temperature 0.9). */
    @GetMapping("/call")
    public String call(@RequestParam(value = "query", defaultValue = DEFAULT_POEM_QUERY) String query) {
        return chatClient.prompt(query).call().content();
    }

    /** Overrides the temperature to 0.0 for this request only. */
    @GetMapping("/call/temperature")
    public String callOption(@RequestParam(value = "query", defaultValue = DEFAULT_POEM_QUERY) String query) {
        OpenAiChatOptions perRequest = OpenAiChatOptions.builder()
                .temperature(0.0)
                .build();
        return chatClient.prompt(query)
                .options(perRequest)
                .call()
                .content();
    }
}
chatClient 全局配置 temperature=0.9
- /call:使用的是 temperature=0.9
- /call/temperature:当前请求覆盖配置,temperature=0.0
效果
/call 接口的 temperature=0.9
/call/temperature 接口的 temperature=0.0
ChatClient + ChatModel 自动注入篇
[!TIP] 配置 pom 文件后,自动注入 ChatModel、ChatClient.Builder 的原理
pom.xml 文件
引入 ChatClient 依赖
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId></dependency>
选择 chat 模型,这里使用 openai
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-openai</artifactId></dependency>
ChatClient 自动注入
ChatClientBuilderProperties
类的作用:
- 控制是否提供 ChatClient.Builder 聊天客户端构建器的 Bean,默认为 true
- 配置观测日志的行为,如是否记录提示词内容
package org.springframework.ai.model.chat.client.autoconfigure;
import org.springframework.boot.context.properties.ConfigurationProperties;
/**
 * Properties bound from {@code spring.ai.chat.client.*}.
 */
@ConfigurationProperties(ChatClientBuilderProperties.CONFIGPREFIX)
public class ChatClientBuilderProperties {

    public static final String CONFIGPREFIX = "spring.ai.chat.client";

    /** Bound from {@code spring.ai.chat.client.enabled}; defaults to {@code true}. */
    private boolean enabled = true;

    /** Nested {@code spring.ai.chat.client.observations.*} settings; always non-null. */
    private final Observations observations = new Observations();

    public boolean isEnabled() {
        return enabled;
    }

    public void setEnabled(boolean enabled) {
        this.enabled = enabled;
    }

    public Observations getObservations() {
        return observations;
    }

    /** Observation-related settings. */
    public static class Observations {

        /** Bound from {@code ...observations.log-prompt}; defaults to {@code false}. */
        private boolean logPrompt;

        public boolean isLogPrompt() {
            return logPrompt;
        }

        public void setLogPrompt(boolean logPrompt) {
            this.logPrompt = logPrompt;
        }
    }
}
ChatClientBuilderConfigurer
类的作用:
- 用于对 ChatClient.Builder 聊天客户端构建器进行扩展性配置
- 通过注册不同的 ChatClientCustomizer 实现,可动态调整聊天客户端
package org.springframework.ai.model.chat.client.autoconfigure;
import java.util.List;import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.chat.client.ChatClientCustomizer;
/**
 * Applies every registered {@link ChatClientCustomizer} to a {@link ChatClient.Builder}.
 */
public class ChatClientBuilderConfigurer {

    /** Customizers to apply, in list order; {@code null} means none have been set. */
    private List<ChatClientCustomizer> customizers;

    void setChatClientCustomizers(List<ChatClientCustomizer> customizers) {
        this.customizers = customizers;
    }

    /**
     * Runs all customizers against {@code builder}.
     *
     * @param builder the builder to customize (modified in place)
     * @return the same builder instance
     */
    public ChatClient.Builder configure(ChatClient.Builder builder) {
        applyCustomizers(builder);
        return builder;
    }

    private void applyCustomizers(ChatClient.Builder builder) {
        if (this.customizers == null) {
            return;
        }
        this.customizers.forEach(customizer -> customizer.customize(builder));
    }
}
ChatClientCustomizer
可通过实现 ChatClientCustomizer 函数式接口,自定义调整 ChatClient.Builder 的相关配置
package org.springframework.ai.chat.client;
@FunctionalInterfacepublic interface ChatClientCustomizer { void customize(ChatClient.Builder chatClientBuilder);}
ChatClientAutoConfiguration
类上重点注解说明
- 在 ObservationAutoConfiguration 类之后加载,确保观测基础设施已就绪
- 仅当类路径中存在 ChatClient 类时才启用该自动配置
- 启用 ChatClientBuilderProperties 配置属性的支持,将配置文件中的 spring.ai.chat.client.* 映射到该类实例
- 只有当配置项 spring.ai.chat.client.enabled=true 时才启用该自动配置;未配置该项时默认为 true
对外提供 Bean
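- ChatClientBuilderConfigurer:从容器中按顺序(orderedStream)获取所有 ChatClientCustomizer 实例,用于配置 ChatClient.Builder
- ChatClient.Builder:使用 ChatModel(以及容器中唯一的 ObservationRegistry、ChatClientObservationConvention,若存在)初始化 ChatClient.Builder,再交由 ChatClientBuilderConfigurer 应用所有 ChatClientCustomizer
  - @Scope("prototype"):每次注入都会生成新的实例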
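内部配置类 TracerPresentObservationConfiguration、TracerNotPresentObservationConfiguration
- 两者提供的日志处理器都只在配置项 spring.ai.chat.client.observations.log-prompt=true 时注册
- 当类路径中存在 Tracer 类且容器中存在 Tracer Bean 时(TracerPresentObservationConfiguration):注册带追踪能力的 TracingAwareLoggingObservationHandler,用于记录提示词内容,并输出安全警告日志
- 当类路径中不存在 Tracer 类时(TracerNotPresentObservationConfiguration):注册普通的 ChatClientPromptContentObservationHandler,仅记录提示词内容,并输出安全警告日志
- 注意:若类路径中存在 Tracer 类但容器中没有 Tracer Bean,两个配置类都不会生效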
package org.springframework.ai.model.chat.client.autoconfigure;
import io.micrometer.observation.ObservationRegistry;import io.micrometer.tracing.Tracer;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.ai.chat.client.ChatClient;import org.springframework.ai.chat.client.ChatClientCustomizer;import org.springframework.ai.chat.client.observation.ChatClientObservationContext;import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;import org.springframework.ai.chat.client.observation.ChatClientPromptContentObservationHandler;import org.springframework.ai.chat.model.ChatModel;import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;import org.springframework.beans.factory.ObjectProvider;import org.springframework.boot.autoconfigure.AutoConfiguration;import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;import org.springframework.boot.context.properties.EnableConfigurationProperties;import org.springframework.context.annotation.Bean;import org.springframework.context.annotation.Configuration;import org.springframework.context.annotation.Scope;
@AutoConfiguration( afterName = {"org.springframework.boot.actuate.autoconfigure.observation.ObservationAutoConfiguration"})@ConditionalOnClass({ChatClient.class})@EnableConfigurationProperties({ChatClientBuilderProperties.class})@ConditionalOnProperty( prefix = "spring.ai.chat.client", name = {"enabled"}, havingValue = "true", matchIfMissing = true)public class ChatClientAutoConfiguration { private static final Logger logger = LoggerFactory.getLogger(ChatClientAutoConfiguration.class);
private static void logPromptContentWarning() { logger.warn("You have enabled logging out the ChatClient prompt content with the risk of exposing sensitive or private information. Please, be careful!"); }
@Bean @ConditionalOnMissingBean ChatClientBuilderConfigurer chatClientBuilderConfigurer(ObjectProvider<ChatClientCustomizer> customizerProvider) { ChatClientBuilderConfigurer configurer = new ChatClientBuilderConfigurer(); configurer.setChatClientCustomizers(customizerProvider.orderedStream().toList()); return configurer; }
@Bean @Scope("prototype") @ConditionalOnMissingBean ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuilderConfigurer, ChatModel chatModel, ObjectProvider<ObservationRegistry> observationRegistry, ObjectProvider<ChatClientObservationConvention> observationConvention) { ChatClient.Builder builder = ChatClient.builder(chatModel, (ObservationRegistry)observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP), (ChatClientObservationConvention)observationConvention.getIfUnique(() -> null)); return chatClientBuilderConfigurer.configure(builder); }
@Configuration( proxyBeanMethods = false ) @ConditionalOnClass({Tracer.class}) @ConditionalOnBean({Tracer.class}) static class TracerPresentObservationConfiguration { @Bean @ConditionalOnMissingBean( value = {ChatClientPromptContentObservationHandler.class}, name = {"chatClientPromptContentObservationHandler"} ) @ConditionalOnProperty( prefix = "spring.ai.chat.client.observations", name = {"log-prompt"}, havingValue = "true" ) TracingAwareLoggingObservationHandler<ChatClientObservationContext> chatClientPromptContentObservationHandler(Tracer tracer) { ChatClientAutoConfiguration.logPromptContentWarning(); return new TracingAwareLoggingObservationHandler(new ChatClientPromptContentObservationHandler(), tracer); } }
@Configuration( proxyBeanMethods = false ) @ConditionalOnMissingClass({"io.micrometer.tracing.Tracer"}) static class TracerNotPresentObservationConfiguration { @Bean @ConditionalOnMissingBean @ConditionalOnProperty( prefix = "spring.ai.chat.client.observations", name = {"log-prompt"}, havingValue = "true" ) ChatClientPromptContentObservationHandler chatClientPromptContentObservationHandler() { ChatClientAutoConfiguration.logPromptContentWarning(); return new ChatClientPromptContentObservationHandler(); } }}
ChatModel 自动注入
OpenAiParentProperties
基础连接配置信息(可从 OpenAI 开发者平台获取):
- apiKey(必填):密钥
- baseUrl(选填):调用的基础 URL,未填写时使用默认值 https://api.openai.com,详见 OpenAiConnectionProperties 类的 DEFAULTBASEURL 字段
- projectId(选填):项目 Id
- organizationId(选填):组织 Id
package org.springframework.ai.model.openai.autoconfigure;
/**
 * Connection settings shared by the OpenAI property classes.
 * All values start as {@code null} until bound or set.
 */
class OpenAiParentProperties {

    /** API key used to authenticate requests. */
    private String apiKey;

    /** Base URL of the OpenAI-compatible endpoint. */
    private String baseUrl;

    /** Optional project identifier. */
    private String projectId;

    /** Optional organization identifier. */
    private String organizationId;

    public String getApiKey() {
        return apiKey;
    }

    public String getBaseUrl() {
        return baseUrl;
    }

    public String getProjectId() {
        return projectId;
    }

    public String getOrganizationId() {
        return organizationId;
    }

    public void setApiKey(String apiKey) {
        this.apiKey = apiKey;
    }

    public void setBaseUrl(String baseUrl) {
        this.baseUrl = baseUrl;
    }

    public void setProjectId(String projectId) {
        this.projectId = projectId;
    }

    public void setOrganizationId(String organizationId) {
        this.organizationId = organizationId;
    }
}
OpenAiConnectionProperties
连接配置类(对应 spring.ai.openai.*):baseUrl 默认为 DEFAULTBASEURL(https://api.openai.com),若配置文件中设置了 spring.ai.openai.base-url,则会覆盖该默认值
package org.springframework.ai.model.openai.autoconfigure;
import org.springframework.boot.context.properties.ConfigurationProperties;
/**
 * Shared OpenAI connection properties bound from {@code spring.ai.openai.*}.
 * The base URL starts at {@link #DEFAULTBASEURL}; a configured value replaces it during binding.
 */
@ConfigurationProperties(OpenAiConnectionProperties.CONFIGPREFIX)
public class OpenAiConnectionProperties extends OpenAiParentProperties {

    public static final String CONFIGPREFIX = "spring.ai.openai";

    public static final String DEFAULTBASEURL = "https://api.openai.com";

    public OpenAiConnectionProperties() {
        super.setBaseUrl(DEFAULTBASEURL);
    }
}
OpenAiChatProperties
Chat 配置类。
- 配置 Chat Model,默认为“gpt-4o-mini”
- 配置 Chat 接口路径,默认为“/v1/chat/completions”
- 配置 temperature,默认为 0.7(OpenAI 接口的取值范围为 0~2,其他厂商的范围可能不同)
  - 值越低输出越确定(设为 0 时输出最稳定,但并不保证相同输入每次都产生完全相同的输出)
  - 值越高随机性越强(产生更开放或不常见的回答,适用于创意写作等场景)
package org.springframework.ai.model.openai.autoconfigure;
import org.springframework.ai.openai.OpenAiChatOptions;import org.springframework.boot.context.properties.ConfigurationProperties;import org.springframework.boot.context.properties.NestedConfigurationProperty;
/**
 * Chat-specific OpenAI properties bound from {@code spring.ai.openai.chat.*}.
 *
 * <p>The defaults are model {@link #DEFAULTCHATMODEL}, temperature 0.7 and completions path
 * {@link #DEFAULTCOMPLETIONSPATH}. Connection fields inherited from the parent class may
 * override the shared {@code spring.ai.openai.*} values for chat.
 */
@ConfigurationProperties(OpenAiChatProperties.CONFIGPREFIX)
public class OpenAiChatProperties extends OpenAiParentProperties {

    public static final String CONFIGPREFIX = "spring.ai.openai.chat";

    public static final String DEFAULTCHATMODEL = "gpt-4o-mini";

    public static final String DEFAULTCOMPLETIONSPATH = "/v1/chat/completions";

    private static final Double DEFAULTTEMPERATURE = 0.7;

    /** Path appended to the base URL for chat completions. */
    private String completionsPath = DEFAULTCOMPLETIONSPATH;

    /** Default chat options; nested properties such as {@code options.model} bind into it. */
    @NestedConfigurationProperty
    private OpenAiChatOptions options;

    public OpenAiChatProperties() {
        this.options = OpenAiChatOptions.builder()
                .model(DEFAULTCHATMODEL)
                .temperature(DEFAULTTEMPERATURE)
                .build();
    }

    public OpenAiChatOptions getOptions() {
        return options;
    }

    public void setOptions(OpenAiChatOptions options) {
        this.options = options;
    }

    public String getCompletionsPath() {
        return completionsPath;
    }

    public void setCompletionsPath(String completionsPath) {
        this.completionsPath = completionsPath;
    }
}
OpenAiChatAutoConfiguration
类上重点注解说明
- 确保网络客户端(RestClient、WebClient)、重试机制、工具调用就绪后再注入
- 当类路径有 OpenAiApi 类时才启用该自动配置
- 启用 OpenAiConnectionProperties、OpenAiChatProperties 配置属性的支持
- 只有当配置项 spring.ai.model.chat=openai 时才会启用该自动配置;未配置该项时默认启用(matchIfMissing = true)
对外提供了 OpenAiChatModel 的 Bean
-
使用 openAiApi 方法构建底层 API 实例,通过 OpenAiChatModel.builder()构建 Chat 模型,另外配置了默认选项、工具调用、重试策略、观测注册表
-
openAiApi 侧封装了 OpenAI API 的构建逻辑,包括基础 URL、API Key、请求头、请求路径、HTTP 客户端等配置
- 注:openAiApi 是私有方法,它创建的 OpenAiApi 实例不会注册为 Spring Bean
package org.springframework.ai.model.openai.autoconfigure;
import io.micrometer.observation.ObservationRegistry;import java.util.Objects;import org.springframework.ai.chat.observation.ChatModelObservationConvention;import org.springframework.ai.model.SimpleApiKey;import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;import org.springframework.ai.model.tool.ToolCallingManager;import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;import org.springframework.ai.openai.OpenAiChatModel;import org.springframework.ai.openai.api.OpenAiApi;import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;import org.springframework.beans.factory.ObjectProvider;import org.springframework.boot.autoconfigure.AutoConfiguration;import org.springframework.boot.autoconfigure.ImportAutoConfiguration;import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration;import org.springframework.boot.context.properties.EnableConfigurationProperties;import org.springframework.context.annotation.Bean;import org.springframework.retry.support.RetryTemplate;import org.springframework.web.client.ResponseErrorHandler;import org.springframework.web.client.RestClient;import org.springframework.web.reactive.function.client.WebClient;
@AutoConfiguration( after = {RestClientAutoConfiguration.class, WebClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class})@ConditionalOnClass({OpenAiApi.class})@EnableConfigurationProperties({OpenAiConnectionProperties.class, OpenAiChatProperties.class})@ConditionalOnProperty( name = {"spring.ai.model.chat"}, havingValue = "openai", matchIfMissing = true)@ImportAutoConfiguration( classes = {SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class, WebClientAutoConfiguration.class, ToolCallingAutoConfiguration.class})public class OpenAiChatAutoConfiguration { @Bean @ConditionalOnMissingBean public OpenAiChatModel openAiChatModel(OpenAiConnectionProperties commonProperties, OpenAiChatProperties chatProperties, ObjectProvider<RestClient.Builder> restClientBuilderProvider, ObjectProvider<WebClient.Builder> webClientBuilderProvider, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler, ObjectProvider<ObservationRegistry> observationRegistry, ObjectProvider<ChatModelObservationConvention> observationConvention, ObjectProvider<ToolExecutionEligibilityPredicate> openAiToolExecutionEligibilityPredicate) { OpenAiApi openAiApi = this.openAiApi(chatProperties, commonProperties, (RestClient.Builder)restClientBuilderProvider.getIfAvailable(RestClient::builder), (WebClient.Builder)webClientBuilderProvider.getIfAvailable(WebClient::builder), responseErrorHandler, "chat"); OpenAiChatModel chatModel = OpenAiChatModel.builder().openAiApi(openAiApi).defaultOptions(chatProperties.getOptions()).toolCallingManager(toolCallingManager).toolExecutionEligibilityPredicate((ToolExecutionEligibilityPredicate)openAiToolExecutionEligibilityPredicate.getIfUnique(DefaultToolExecutionEligibilityPredicate::new)).retryTemplate(retryTemplate).observationRegistry((ObservationRegistry)observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)).build(); 
Objects.requireNonNull(chatModel); observationConvention.ifAvailable(chatModel::setObservationConvention); return chatModel; }
private OpenAiApi openAiApi(OpenAiChatProperties chatProperties, OpenAiConnectionProperties commonProperties, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler, String modelType) { OpenAIAutoConfigurationUtil.ResolvedConnectionProperties resolved = OpenAIAutoConfigurationUtil.resolveConnectionProperties(commonProperties, chatProperties, modelType); return OpenAiApi.builder().baseUrl(resolved.baseUrl()).apiKey(new SimpleApiKey(resolved.apiKey())).headers(resolved.headers()).completionsPath(chatProperties.getCompletionsPath()).embeddingsPath("/v1/embeddings").restClientBuilder(restClientBuilder).webClientBuilder(webClientBuilder).responseErrorHandler(responseErrorHandler).build(); }}
工具类:OpenAIAutoConfigurationUtil
- 模型级配置(如 spring.ai.openai.chat.*)优先,未配置时回退到通用配置(spring.ai.openai.*);在构建 OpenAiApi 之前,校验 apiKey、baseUrl 均不为空
- 根据 projectId、organizationId 设置请求头
package org.springframework.ai.model.openai.autoconfigure;
import java.util.HashMap;import java.util.List;import java.util.Map;import org.jetbrains.annotations.NotNull;import org.springframework.util.Assert;import org.springframework.util.CollectionUtils;import org.springframework.util.MultiValueMap;import org.springframework.util.StringUtils;
/**
 * Helpers shared by the OpenAI auto-configurations for resolving connection settings.
 */
public final class OpenAIAutoConfigurationUtil {

    private OpenAIAutoConfigurationUtil() {
    }

    /**
     * Resolves the effective connection settings.
     *
     * <p>Each value from {@code modelProperties} wins when it has text; otherwise the value from
     * {@code commonProperties} is used. A project id or organization id, when present, becomes the
     * {@code OpenAI-Project} or {@code OpenAI-Organization} header.
     *
     * @param commonProperties shared connection properties ({@code spring.ai.openai.*})
     * @param modelProperties model-specific properties ({@code spring.ai.openai.<modelType>.*})
     * @param modelType property segment used in error messages, e.g. {@code "chat"}
     * @return the resolved base URL, API key and extra headers
     * @throws IllegalArgumentException if no base URL or no API key could be resolved
     */
    @NotNull
    public static ResolvedConnectionProperties resolveConnectionProperties(OpenAiParentProperties commonProperties,
            OpenAiParentProperties modelProperties, String modelType) {
        String baseUrl = firstWithText(modelProperties.getBaseUrl(), commonProperties.getBaseUrl());
        String apiKey = firstWithText(modelProperties.getApiKey(), commonProperties.getApiKey());
        String projectId = firstWithText(modelProperties.getProjectId(), commonProperties.getProjectId());
        String organizationId = firstWithText(modelProperties.getOrganizationId(),
                commonProperties.getOrganizationId());

        // Typed map (was a raw HashMap, which produced an unchecked conversion).
        Map<String, List<String>> connectionHeaders = new HashMap<>();
        if (StringUtils.hasText(projectId)) {
            connectionHeaders.put("OpenAI-Project", List.of(projectId));
        }
        if (StringUtils.hasText(organizationId)) {
            connectionHeaders.put("OpenAI-Organization", List.of(organizationId));
        }

        Assert.hasText(baseUrl, "OpenAI base URL must be set. Use the connection property: spring.ai.openai.base-url or spring.ai.openai." + modelType + ".base-url property.");
        Assert.hasText(apiKey, "OpenAI API key must be set. Use the connection property: spring.ai.openai.api-key or spring.ai.openai." + modelType + ".api-key property.");
        return new ResolvedConnectionProperties(baseUrl, apiKey, CollectionUtils.toMultiValueMap(connectionHeaders));
    }

    /** Returns {@code preferred} when it has text, otherwise {@code fallback}. */
    private static String firstWithText(String preferred, String fallback) {
        return StringUtils.hasText(preferred) ? preferred : fallback;
    }

    /** Resolved connection settings handed to the {@code OpenAiApi} builder. */
    public record ResolvedConnectionProperties(String baseUrl, String apiKey, MultiValueMap<String, String> headers) {
    }
}
ChatClient 解读
ChatClient 端设置 advisors、ChatOptions、用户提示信息、系统提示信息、工具等信息,构建 DefaultChatClient.DefaultChatClientRequestSpec,再利用 DefaultChatClientUtils 将其转换为 ChatClientRequest
AdvisorChain 链调用一系列的增强器Advisor,每个增强器输入是 ChatClientRequest,输出 ChatClientResponse(其中必定会用到的是 ChatModelCallAdvisor 或 ChatModelStreamAdvisor)
ChatClient
类的说明:面向对话式 AI 模型的客户端接口,提供了一系列与 AI 对话模型交互的 API。该接口封装了请求构建、调用、响应处理等流程,支持同步调用和流式调用
方法说明
方法名称 | 描述 |
create(静态方法) | 由ChatModel、观测信息等创建 ChatClient 实例 |
builder(静态方法) | 由ChatModel、观测信息等创建 ChatClient.Builder实例 |
mutate | 复制当前客户端配置,生成新的ChatClient.Builder实例 |
prompt | 构建ChatClientRequestSpec实例 |
内部接口类说明
接口类 | 方法名称 | 描述 |
Builder (全局的ChatClient配置) | defaultAdvisors | 设置advisors |
defaultOptions | 设置ChatOptions | |
defaultUser | 设置用户提示信息 | |
defaultSystem | 设置系统提示信息 | |
defaultTemplateRenderer | 设置模版渲染器,用于处理字符串的占位符 | |
defaultToolNames | 根据工具名称获取工具配置 | |
defaultTools | 根据实例获取工具配置 | |
defaultToolCallbacks | 根据ToolCallback获取工具配置 | |
defaultToolContext | 工具的上下文 | |
clone | 复制当前客户端配置,生成新的ChatClient.Builder实例 | |
build | 构建最终的ChatClient实例 | |
ChatClientRequestSpec (当前的ChatClient配置) | advisors | 设置advisors |
options | 设置ChatOptions | |
user | 设置用户提示信息 | |
system | 设置系统提示信息 | |
templateRenderer | 设置模版渲染器,用于处理字符串的占位符 | |
toolNames | 根据工具名称获取工具配置 | |
tools | 根据实例获取工具配置 | |
toolCallbacks | 根据ToolCallback获取工具配置 | |
toolContext | 工具的上下文 | |
mutate | 复制当前客户端配置,生成新的ChatClient.Builder实例 | |
messages | 添加Message | |
call | 同步调用 | |
stream | 流式调用 | |
PromptUserSpec (用户提示信息的构建规范) | text | 设置用户文本内容 |
param | 设置参数 | |
params | 设置参数 | |
media | 设置多媒体内容(如图片) | |
PromptSystemSpec (系统提示信息的构建规范) | text | 设置系统指令 |
param | 设置参数 | |
params | 设置参数 | |
AdvisorSpec (设置增强器) | param | 设置增强器中会用到的一些参数配置 |
advisors | 添加增强器 | |
CallResponseSpec | entity | 将响应体转换为指定类型 |
chatClientResponse | 原始响应对象+请求时的上下文内容 | |
chatResponse | 原始的响应对象 | |
content | 响应的文本内容 | |
responseEntity | 同时获取原始 ChatResponse 与转换后的实体对象(ResponseEntity<ChatResponse, T>) | |
StreamResponseSpec | chatClientResponse | 流式的原始响应对象+请求时的上下文内容 |
chatResponse | 流式的原始响应对象 | |
content | 流式的响应文本内容 | |
/**
 * Fluent client for conversational AI models. A request is described with a
 * {@link ChatClientRequestSpec}, obtained from {@code prompt(...)}. It is executed either
 * synchronously via {@code call()} or reactively via {@code stream()}.
 */
public interface ChatClient {
/** Creates a client for {@code chatModel} with observations disabled ({@code ObservationRegistry.NOOP}). */
static ChatClient create(ChatModel chatModel) { return create(chatModel, ObservationRegistry.NOOP); }
/** Creates a client with the given registry and no custom observation convention ({@code null}). */
static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry) { return create(chatModel, observationRegistry, (ChatClientObservationConvention)null); }
/**
 * Creates a client; equivalent to {@code builder(chatModel, observationRegistry, observationConvention).build()}.
 * A null {@code chatModel} or {@code observationRegistry} is rejected via {@code Assert.notNull}.
 */
static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention) { Assert.notNull(chatModel, "chatModel cannot be null"); Assert.notNull(observationRegistry, "observationRegistry cannot be null"); return builder(chatModel, observationRegistry, observationConvention).build(); }
/** Returns a builder for {@code chatModel} with observations disabled and no custom convention. */
static Builder builder(ChatModel chatModel) { return builder(chatModel, ObservationRegistry.NOOP, (ChatClientObservationConvention)null); }
/**
 * Returns a new {@code DefaultChatClientBuilder}.
 * A null {@code chatModel} or {@code observationRegistry} is rejected via {@code Assert.notNull}.
 */
static Builder builder(ChatModel chatModel, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention customObservationConvention) { Assert.notNull(chatModel, "chatModel cannot be null"); Assert.notNull(observationRegistry, "observationRegistry cannot be null"); return new DefaultChatClientBuilder(chatModel, observationRegistry, customObservationConvention); }
/** Starts a new request spec. */
ChatClientRequestSpec prompt();
/** Starts a new request spec from the given text content. */
ChatClientRequestSpec prompt(String content);
/** Starts a new request spec from an existing {@link Prompt}. */
ChatClientRequestSpec prompt(Prompt prompt);
/** Returns a new {@link Builder} based on this client's configuration. */
Builder mutate();
/** Adds advisors and the parameters that advisors can read. */
public interface AdvisorSpec { AdvisorSpec param(String k, Object v);
AdvisorSpec params(Map<String, Object> p);
AdvisorSpec advisors(Advisor... advisors);
AdvisorSpec advisors(List<Advisor> advisors); }
/**
 * Client-wide defaults (advisors, options, user/system text, template renderer, tools, tool
 * context) applied to every request made through the built {@link ChatClient}.
 */
public interface Builder { Builder defaultAdvisors(Advisor... advisor);
Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer);
Builder defaultAdvisors(List<Advisor> advisors);
Builder defaultOptions(ChatOptions chatOptions);
Builder defaultUser(String text);
Builder defaultUser(Resource text, Charset charset);
Builder defaultUser(Resource text);
Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer);
Builder defaultSystem(String text);
Builder defaultSystem(Resource text, Charset charset);
Builder defaultSystem(Resource text);
Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);
Builder defaultTemplateRenderer(TemplateRenderer templateRenderer);
Builder defaultToolNames(String... toolNames);
Builder defaultTools(Object... toolObjects);
Builder defaultToolCallbacks(ToolCallback... toolCallbacks);
Builder defaultToolCallbacks(List<ToolCallback> toolCallbacks);
Builder defaultToolCallbacks(ToolCallbackProvider... toolCallbackProviders);
Builder defaultToolContext(Map<String, Object> toolContext);
/** Copies this builder's current configuration into a new builder. */
Builder clone();
ChatClient build(); }
/** NOTE(review): not returned by call()/stream() in this interface — possibly a legacy type; confirm before use. */
public interface CallPromptResponseSpec { String content();
List<String> contents();
ChatResponse chatResponse(); }
/**
 * Result accessors for a synchronous {@code call()}. The {@code entity(...)} and
 * {@code responseEntity(...)} methods convert the model output into {@code T}.
 * {@code content()} and {@code chatResponse()} may return {@code null}.
 */
public interface CallResponseSpec { @Nullable <T> T entity(ParameterizedTypeReference<T> type);
@Nullable <T> T entity(StructuredOutputConverter<T> structuredOutputConverter);
@Nullable <T> T entity(Class<T> type);
ChatClientResponse chatClientResponse();
@Nullable ChatResponse chatResponse();
@Nullable String content();
<T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type);
<T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type);
<T> ResponseEntity<ChatResponse, T> responseEntity(StructuredOutputConverter<T> structuredOutputConverter); }
/**
 * Per-request configuration. For example, options set here override the client-wide defaults
 * for this request only. Finish with {@code call()} or {@code stream()}.
 */
public interface ChatClientRequestSpec { Builder mutate();
ChatClientRequestSpec advisors(Consumer<AdvisorSpec> consumer);
ChatClientRequestSpec advisors(Advisor... advisors);
ChatClientRequestSpec advisors(List<Advisor> advisors);
ChatClientRequestSpec messages(Message... messages);
ChatClientRequestSpec messages(List<Message> messages);
<T extends ChatOptions> ChatClientRequestSpec options(T options);
ChatClientRequestSpec toolNames(String... toolNames);
ChatClientRequestSpec tools(Object... toolObjects);
ChatClientRequestSpec toolCallbacks(ToolCallback... toolCallbacks);
ChatClientRequestSpec toolCallbacks(List<ToolCallback> toolCallbacks);
ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... toolCallbackProviders);
ChatClientRequestSpec toolContext(Map<String, Object> toolContext);
ChatClientRequestSpec system(String text);
ChatClientRequestSpec system(Resource textResource, Charset charset);
ChatClientRequestSpec system(Resource text);
ChatClientRequestSpec system(Consumer<PromptSystemSpec> consumer);
ChatClientRequestSpec user(String text);
ChatClientRequestSpec user(Resource text, Charset charset);
ChatClientRequestSpec user(Resource text);
ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);
ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer);
/** Executes the request synchronously. */
CallResponseSpec call();
/** Executes the request as a reactive stream. */
StreamResponseSpec stream(); }
/** Builds the system message: text plus parameters (presumably rendered into placeholders by the TemplateRenderer — confirm). */
public interface PromptSystemSpec { PromptSystemSpec text(String text);
PromptSystemSpec text(Resource text, Charset charset);
PromptSystemSpec text(Resource text);
PromptSystemSpec params(Map<String, Object> p);
PromptSystemSpec param(String k, Object v); }
/** Builds the user message: text, parameters and optional media attachments. */
public interface PromptUserSpec { PromptUserSpec text(String text);
PromptUserSpec text(Resource text, Charset charset);
PromptUserSpec text(Resource text);
PromptUserSpec params(Map<String, Object> p);
PromptUserSpec param(String k, Object v);
PromptUserSpec media(Media... media);
PromptUserSpec media(MimeType mimeType, URL url);
PromptUserSpec media(MimeType mimeType, Resource resource); }
/** NOTE(review): not returned by stream() in this interface — possibly a legacy type; confirm before use. */
public interface StreamPromptResponseSpec { Flux<ChatResponse> chatResponse();
Flux<String> content(); }
/** Reactive result accessors for {@code stream()}. */
public interface StreamResponseSpec { Flux<ChatClientResponse> chatClientResponse();
Flux<ChatResponse> chatResponse();
Flux<String> content(); }}
DefaultChatClient
ChatClient 接口的默认实现类,用于构建和执行与 AI 聊天模型交互的请求
- 内部类 DefaultChatClientRequestSpec 实现了 ChatClient.ChatClientRequestSpec:在 buildAdvisorChain 中向 advisors 末尾追加 ChatModelCallAdvisor 和 ChatModelStreamAdvisor,由它们最终调用 ChatModel
/**
 * Excerpt of DefaultChatClient's request spec. Only buildAdvisorChain() is shown; the fields it
 * reads (advisors, chatModel, observationRegistry, templateRenderer) are declared in omitted code.
 */
public static class DefaultChatClientRequestSpec implements ChatClient.ChatClientRequestSpec {
// Appends ChatModelCallAdvisor and ChatModelStreamAdvisor after all user advisors, then builds
// the chain from the full list.
// NOTE(review): this adds to this.advisors in place, so calling it twice on the same spec would
// append the two advisors twice — confirm it is invoked once per request.
private BaseAdvisorChain buildAdvisorChain() { this.advisors.add(ChatModelCallAdvisor.builder().chatModel(this.chatModel).build()); this.advisors.add(ChatModelStreamAdvisor.builder().chatModel(this.chatModel).build()); return DefaultAroundAdvisorChain.builder(this.observationRegistry).pushAll(this.advisors).templateRenderer(this.templateRenderer).build(); } }
- 内部类 DefaultPromptUserSpec 实现 ChatClient.PromptUserSpec:设置用户文本内容、参数及多媒体内容
- 内部类 DefaultPromptSystemSpec 实现 ChatClient.PromptSystemSpec:设置系统文本内容、参数
- 内部类 DefaultAdvisorSpec 实现 ChatClient.AdvisorSpec:设置 Advisor,及其 advisor 中用到的参数
- 内部类 DefaultCallResponseSpec 实现 ChatClient.CallResponseSpec:通过 doGetObservableChatClientResponse 方法发起请求,调用一系列的 BaseAdvisorChain
/**
 * Synchronous response handle produced by call(). This is an excerpt: only the fields and the
 * core execution method are reproduced.
 */
public static class DefaultCallResponseSpec implements ChatClient.CallResponseSpec {
// The request to execute, the advisor chain it runs through, and the observability settings.
private final ChatClientRequest request; private final BaseAdvisorChain advisorChain; private final ObservationRegistry observationRegistry; private final ChatClientObservationConvention observationConvention;
/**
 * Runs the request through the advisor chain inside a ChatClient observation.
 *
 * @param chatClientRequest request to execute; its context is written to when outputFormat is non-null
 * @param outputFormat optional output-format instructions, may be null
 * @return the chain's response, or an empty ChatClientResponse if the chain returned null
 */
private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest, @Nullable String outputFormat) {
// Record the requested output format in the request context.
// NOTE(review): constant names in this copy appear to have lost their underscores (upstream
// Spring AI: ChatClientAttributes.OUTPUT_FORMAT, AI_CHAT_CLIENT, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION) — confirm before copying.
if (outputFormat != null) { chatClientRequest.context().put(ChatClientAttributes.OUTPUTFORMAT.getKey(), outputFormat); }
// The observation context describes the call: the request, the call advisors, non-streaming, and the output format.
ChatClientObservationContext observationContext = ChatClientObservationContext.builder().request(chatClientRequest).advisors(this.advisorChain.getCallAdvisors()).stream(false).format(outputFormat).build();
// Use the custom convention if one is set, otherwise the default convention.
Observation observation = ChatClientObservationDocumentation.AICHATCLIENT.observation(this.observationConvention, DefaultChatClient.DEFAULTCHATCLIENTOBSERVATIONCONVENTION, () -> observationContext, this.observationRegistry);
// observe() starts the observation, runs the advisor chain, and stops it, recording any error.
ChatClientResponse chatClientResponse = (ChatClientResponse)observation.observe(() -> this.advisorChain.nextCall(chatClientRequest));
// Never hand null to callers; fall back to an empty response.
return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build(); }}
- 内部类 DefaultStreamResponseSpec 实现 ChatClient.StreamResponseSpec:通过 doGetObservableFluxChatResponse 方法发起请求,调用一系列的 BaseAdvisorChain
/**
 * Streaming response handle produced by stream(). This is an excerpt: only the fields and the
 * core execution method are reproduced.
 */
public static class DefaultStreamResponseSpec implements ChatClient.StreamResponseSpec {
// The request to execute, the advisor chain it runs through, and the observability settings.
private final ChatClientRequest request; private final BaseAdvisorChain advisorChain; private final ObservationRegistry observationRegistry; private final ChatClientObservationConvention observationConvention;
/**
 * Runs the request through the streaming advisor chain, wrapped in a ChatClient observation.
 * Flux.deferContextual defers the setup until subscription and gives it the subscriber's
 * Reactor context.
 */
private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) { return Flux.deferContextual((contextView) -> {
// The observation context describes the request: the request itself, the stream advisors, and streaming mode.
ChatClientObservationContext observationContext = ChatClientObservationContext.builder().request(chatClientRequest).advisors(this.advisorChain.getStreamAdvisors()).stream(true).build();
// NOTE(review): underscores appear stripped from constant names in this copy (upstream: AI_CHAT_CLIENT,
// DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION) — confirm before copying.
Observation observation = ChatClientObservationDocumentation.AICHATCLIENT.observation(this.observationConvention, DefaultChatClient.DEFAULTCHATCLIENTOBSERVATIONCONVENTION, () -> observationContext, this.observationRegistry);
// The parent is any observation stored under "micrometer.observation" in the Reactor context; the observation then starts.
observation.parentObservation((Observation)contextView.getOrDefault("micrometer.observation", (Object)null)).start();
// NOTE(review): the raw Flux and var10000 look like decompiler output; upstream likely uses Flux<ChatClientResponse>.
Flux var10000 = this.advisorChain.nextStream(chatClientRequest); Objects.requireNonNull(observation);
// On error, the error is recorded. On completion, error or cancel, the observation stops.
// The observation is published downstream under the same context key.
return var10000.doOnError(observation::error).doFinally((s) -> observation.stop()).contextWrite((ctx) -> ctx.put("micrometer.observation", observation)); }); }}
完整代码如下
package org.springframework.ai.chat.client;
/**
 * Default {@link ChatClient} implementation.
 *
 * <p>Every {@code prompt(...)} call creates a new {@link DefaultChatClientRequestSpec} copied from the
 * default spec held by this client, so per-request settings are applied to the copy. Calling
 * {@code call()} or {@code stream()} on that spec converts it into a {@link ChatClientRequest}. The
 * request is then executed through an advisor chain that ends with the model advisors
 * ({@link ChatModelCallAdvisor} / {@link ChatModelStreamAdvisor}).
 */
public class DefaultChatClient implements ChatClient {

    private static final ChatClientObservationConvention DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION = new DefaultChatClientObservationConvention();

    private static final TemplateRenderer DEFAULT_TEMPLATE_RENDERER = StTemplateRenderer.builder().build();

    /** Template spec; copied for every new prompt. */
    private final DefaultChatClientRequestSpec defaultChatClientRequest;

    public DefaultChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) {
        Assert.notNull(defaultChatClientRequest, "defaultChatClientRequest cannot be null");
        this.defaultChatClientRequest = defaultChatClientRequest;
    }

    /** Starts a new request from a copy of the default spec. */
    public ChatClient.ChatClientRequestSpec prompt() {
        return new DefaultChatClientRequestSpec(this.defaultChatClientRequest);
    }

    /** Starts a new request whose content is {@code content}, wrapped in a {@link Prompt}. */
    public ChatClient.ChatClientRequestSpec prompt(String content) {
        Assert.hasText(content, "content cannot be null or empty");
        return this.prompt(new Prompt(content));
    }

    /**
     * Starts a new request from a copy of the default spec. The prompt's options (if any) replace the
     * spec's options, and its instructions (if any) are appended to the spec's messages.
     */
    public ChatClient.ChatClientRequestSpec prompt(Prompt prompt) {
        Assert.notNull(prompt, "prompt cannot be null");
        DefaultChatClientRequestSpec spec = new DefaultChatClientRequestSpec(this.defaultChatClientRequest);
        if (prompt.getOptions() != null) {
            spec.options(prompt.getOptions());
        }
        if (prompt.getInstructions() != null) {
            spec.messages(prompt.getInstructions());
        }
        return spec;
    }

    /** Returns a builder pre-populated from the default spec. */
    public ChatClient.Builder mutate() {
        return this.defaultChatClientRequest.mutate();
    }

    /** Collects user text, template params and media configured through {@code user(Consumer)}. */
    public static class DefaultPromptUserSpec implements ChatClient.PromptUserSpec {

        private final Map<String, Object> params = new HashMap<>();

        private final List<Media> media = new ArrayList<>();

        @Nullable
        private String text;

        public ChatClient.PromptUserSpec media(Media... media) {
            Assert.notNull(media, "media cannot be null");
            Assert.noNullElements(media, "media cannot contain null elements");
            this.media.addAll(Arrays.asList(media));
            return this;
        }

        /** Adds media referenced by {@code url}; a malformed URI is rethrown as RuntimeException. */
        public ChatClient.PromptUserSpec media(MimeType mimeType, URL url) {
            Assert.notNull(mimeType, "mimeType cannot be null");
            Assert.notNull(url, "url cannot be null");
            try {
                this.media.add(Media.builder().mimeType(mimeType).data(url.toURI()).build());
                return this;
            }
            catch (URISyntaxException e) {
                throw new RuntimeException(e);
            }
        }

        public ChatClient.PromptUserSpec media(MimeType mimeType, Resource resource) {
            Assert.notNull(mimeType, "mimeType cannot be null");
            Assert.notNull(resource, "resource cannot be null");
            this.media.add(Media.builder().mimeType(mimeType).data(resource).build());
            return this;
        }

        public ChatClient.PromptUserSpec text(String text) {
            Assert.hasText(text, "text cannot be null or empty");
            this.text = text;
            return this;
        }

        /** Reads the text from a resource; I/O failures are rethrown as RuntimeException. */
        public ChatClient.PromptUserSpec text(Resource text, Charset charset) {
            Assert.notNull(text, "text cannot be null");
            Assert.notNull(charset, "charset cannot be null");
            try {
                this.text(text.getContentAsString(charset));
                return this;
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        /** Reads the text from a resource using the platform default charset. */
        public ChatClient.PromptUserSpec text(Resource text) {
            Assert.notNull(text, "text cannot be null");
            this.text(text, Charset.defaultCharset());
            return this;
        }

        public ChatClient.PromptUserSpec param(String key, Object value) {
            Assert.hasText(key, "key cannot be null or empty");
            Assert.notNull(value, "value cannot be null");
            this.params.put(key, value);
            return this;
        }

        public ChatClient.PromptUserSpec params(Map<String, Object> params) {
            Assert.notNull(params, "params cannot be null");
            Assert.noNullElements(params.keySet(), "param keys cannot contain null elements");
            Assert.noNullElements(params.values(), "param values cannot contain null elements");
            this.params.putAll(params);
            return this;
        }

        @Nullable
        protected String text() {
            return this.text;
        }

        protected Map<String, Object> params() {
            return this.params;
        }

        protected List<Media> media() {
            return this.media;
        }

    }

    /** Collects system text and template params configured through {@code system(Consumer)}. */
    public static class DefaultPromptSystemSpec implements ChatClient.PromptSystemSpec {

        private final Map<String, Object> params = new HashMap<>();

        @Nullable
        private String text;

        public ChatClient.PromptSystemSpec text(String text) {
            Assert.hasText(text, "text cannot be null or empty");
            this.text = text;
            return this;
        }

        /** Reads the text from a resource; I/O failures are rethrown as RuntimeException. */
        public ChatClient.PromptSystemSpec text(Resource text, Charset charset) {
            Assert.notNull(text, "text cannot be null");
            Assert.notNull(charset, "charset cannot be null");
            try {
                this.text(text.getContentAsString(charset));
                return this;
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        /** Reads the text from a resource using the platform default charset. */
        public ChatClient.PromptSystemSpec text(Resource text) {
            Assert.notNull(text, "text cannot be null");
            this.text(text, Charset.defaultCharset());
            return this;
        }

        public ChatClient.PromptSystemSpec param(String key, Object value) {
            Assert.hasText(key, "key cannot be null or empty");
            Assert.notNull(value, "value cannot be null");
            this.params.put(key, value);
            return this;
        }

        public ChatClient.PromptSystemSpec params(Map<String, Object> params) {
            Assert.notNull(params, "params cannot be null");
            Assert.noNullElements(params.keySet(), "param keys cannot contain null elements");
            Assert.noNullElements(params.values(), "param values cannot contain null elements");
            this.params.putAll(params);
            return this;
        }

        @Nullable
        protected String text() {
            return this.text;
        }

        protected Map<String, Object> params() {
            return this.params;
        }

    }

    /** Collects advisors and advisor params configured through {@code advisors(Consumer)}. */
    public static class DefaultAdvisorSpec implements ChatClient.AdvisorSpec {

        private final List<Advisor> advisors = new ArrayList<>();

        private final Map<String, Object> params = new HashMap<>();

        public ChatClient.AdvisorSpec param(String key, Object value) {
            Assert.hasText(key, "key cannot be null or empty");
            Assert.notNull(value, "value cannot be null");
            this.params.put(key, value);
            return this;
        }

        public ChatClient.AdvisorSpec params(Map<String, Object> params) {
            Assert.notNull(params, "params cannot be null");
            Assert.noNullElements(params.keySet(), "param keys cannot contain null elements");
            Assert.noNullElements(params.values(), "param values cannot contain null elements");
            this.params.putAll(params);
            return this;
        }

        public ChatClient.AdvisorSpec advisors(Advisor... advisors) {
            Assert.notNull(advisors, "advisors cannot be null");
            Assert.noNullElements(advisors, "advisors cannot contain null elements");
            this.advisors.addAll(List.of(advisors));
            return this;
        }

        public ChatClient.AdvisorSpec advisors(List<Advisor> advisors) {
            Assert.notNull(advisors, "advisors cannot be null");
            Assert.noNullElements(advisors, "advisors cannot contain null elements");
            this.advisors.addAll(advisors);
            return this;
        }

        public List<Advisor> getAdvisors() {
            return this.advisors;
        }

        public Map<String, Object> getParams() {
            return this.params;
        }

    }

    /**
     * Synchronous response spec. Every accessor runs the whole call advisor chain again, inside an
     * AI_CHAT_CLIENT observation.
     */
    public static class DefaultCallResponseSpec implements ChatClient.CallResponseSpec {

        private final ChatClientRequest request;

        private final BaseAdvisorChain advisorChain;

        private final ObservationRegistry observationRegistry;

        private final ChatClientObservationConvention observationConvention;

        public DefaultCallResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain,
                ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) {
            Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
            Assert.notNull(advisorChain, "advisorChain cannot be null");
            Assert.notNull(observationRegistry, "observationRegistry cannot be null");
            Assert.notNull(observationConvention, "observationConvention cannot be null");
            this.request = chatClientRequest;
            this.advisorChain = advisorChain;
            this.observationRegistry = observationRegistry;
            this.observationConvention = observationConvention;
        }

        public <T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type) {
            Assert.notNull(type, "type cannot be null");
            return this.doResponseEntity(new BeanOutputConverter<>(type));
        }

        public <T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type) {
            Assert.notNull(type, "type cannot be null");
            return this.doResponseEntity(new BeanOutputConverter<>(type));
        }

        public <T> ResponseEntity<ChatResponse, T> responseEntity(StructuredOutputConverter<T> structuredOutputConverter) {
            Assert.notNull(structuredOutputConverter, "structuredOutputConverter cannot be null");
            return this.doResponseEntity(structuredOutputConverter);
        }

        /**
         * Executes the request with the converter's format instruction. Returns the raw response plus the
         * converted entity; the entity is null when the response has no text.
         */
        protected <T> ResponseEntity<ChatResponse, T> doResponseEntity(StructuredOutputConverter<T> outputConverter) {
            Assert.notNull(outputConverter, "structuredOutputConverter cannot be null");
            ChatResponse chatResponse = this.doGetObservableChatClientResponse(this.request, outputConverter.getFormat())
                .chatResponse();
            String responseContent = getContentFromChatResponse(chatResponse);
            if (responseContent == null) {
                return new ResponseEntity<>(chatResponse, null);
            }
            T entity = outputConverter.convert(responseContent);
            return new ResponseEntity<>(chatResponse, entity);
        }

        @Nullable
        public <T> T entity(ParameterizedTypeReference<T> type) {
            Assert.notNull(type, "type cannot be null");
            return this.doSingleWithBeanOutputConverter(new BeanOutputConverter<>(type));
        }

        @Nullable
        public <T> T entity(StructuredOutputConverter<T> structuredOutputConverter) {
            Assert.notNull(structuredOutputConverter, "structuredOutputConverter cannot be null");
            return this.doSingleWithBeanOutputConverter(structuredOutputConverter);
        }

        @Nullable
        public <T> T entity(Class<T> type) {
            Assert.notNull(type, "type cannot be null");
            BeanOutputConverter<T> outputConverter = new BeanOutputConverter<>(type);
            return this.doSingleWithBeanOutputConverter(outputConverter);
        }

        /** Executes the request and converts its text; returns null when the response has no text. */
        @Nullable
        private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> outputConverter) {
            ChatResponse chatResponse = this.doGetObservableChatClientResponse(this.request, outputConverter.getFormat())
                .chatResponse();
            String stringResponse = getContentFromChatResponse(chatResponse);
            return stringResponse == null ? null : outputConverter.convert(stringResponse);
        }

        public ChatClientResponse chatClientResponse() {
            return this.doGetObservableChatClientResponse(this.request);
        }

        @Nullable
        public ChatResponse chatResponse() {
            return this.doGetObservableChatClientResponse(this.request).chatResponse();
        }

        @Nullable
        public String content() {
            ChatResponse chatResponse = this.doGetObservableChatClientResponse(this.request).chatResponse();
            return getContentFromChatResponse(chatResponse);
        }

        private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest) {
            return this.doGetObservableChatClientResponse(chatClientRequest, null);
        }

        /**
         * Runs the call advisor chain inside an AI_CHAT_CLIENT observation. A non-null
         * {@code outputFormat} is first stored in the request context under
         * {@code ChatClientAttributes.OUTPUT_FORMAT}. Never returns null.
         */
        private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest,
                @Nullable String outputFormat) {
            if (outputFormat != null) {
                chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), outputFormat);
            }

            ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
                .request(chatClientRequest)
                .advisors(this.advisorChain.getCallAdvisors())
                .stream(false)
                .format(outputFormat)
                .build();

            Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
                    this.observationConvention, DefaultChatClient.DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
                    () -> observationContext, this.observationRegistry);

            ChatClientResponse chatClientResponse = observation.observe(() -> this.advisorChain.nextCall(chatClientRequest));
            return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build();
        }

        /** Null-safe result -> output -> text extraction; null if any step is missing. */
        @Nullable
        private static String getContentFromChatResponse(@Nullable ChatResponse chatResponse) {
            return Optional.ofNullable(chatResponse)
                .map(ChatResponse::getResult)
                .map(Generation::getOutput)
                .map(AbstractMessage::getText)
                .orElse(null);
        }

    }

    /** Streaming response spec; the observation is created per subscription. */
    public static class DefaultStreamResponseSpec implements ChatClient.StreamResponseSpec {

        private final ChatClientRequest request;

        private final BaseAdvisorChain advisorChain;

        private final ObservationRegistry observationRegistry;

        private final ChatClientObservationConvention observationConvention;

        public DefaultStreamResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain,
                ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) {
            Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
            Assert.notNull(advisorChain, "advisorChain cannot be null");
            Assert.notNull(observationRegistry, "observationRegistry cannot be null");
            Assert.notNull(observationConvention, "observationConvention cannot be null");
            this.request = chatClientRequest;
            this.advisorChain = advisorChain;
            this.observationRegistry = observationRegistry;
            this.observationConvention = observationConvention;
        }

        /**
         * Runs the stream advisor chain inside an AI_CHAT_CLIENT observation.
         *
         * <p>The observation is parented on any observation stored in the Reactor context under
         * {@code "micrometer.observation"}. Errors are recorded and the observation is stopped on
         * termination. The observation is then written into the context seen by the advisor chain.
         */
        private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) {
            return Flux.deferContextual(contextView -> {
                ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
                    .request(chatClientRequest)
                    .advisors(this.advisorChain.getStreamAdvisors())
                    .stream(true)
                    .build();

                Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
                        this.observationConvention, DefaultChatClient.DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
                        () -> observationContext, this.observationRegistry);

                observation.parentObservation(contextView.getOrDefault("micrometer.observation", null)).start();

                return this.advisorChain.nextStream(chatClientRequest)
                    .doOnError(observation::error)
                    .doFinally(s -> observation.stop())
                    .contextWrite(ctx -> ctx.put("micrometer.observation", observation));
            });
        }

        public Flux<ChatClientResponse> chatClientResponse() {
            return this.doGetObservableFluxChatResponse(this.request);
        }

        public Flux<ChatResponse> chatResponse() {
            return this.doGetObservableFluxChatResponse(this.request).mapNotNull(ChatClientResponse::chatResponse);
        }

        /**
         * Emits only non-empty text chunks. Responses missing a result, an output or text map to ""
         * and are then filtered out.
         */
        public Flux<String> content() {
            return this.doGetObservableFluxChatResponse(this.request)
                .mapNotNull(ChatClientResponse::chatResponse)
                .map(r -> r.getResult() != null && r.getResult().getOutput() != null
                        && r.getResult().getOutput().getText() != null ? r.getResult().getOutput().getText() : "")
                .filter(StringUtils::hasLength);
        }

    }

    /**
     * Mutable per-request specification: text and params, messages, media, options, tools and advisors.
     * The constructor copies every collection and the chat options into fresh instances.
     */
    public static class DefaultChatClientRequestSpec implements ChatClient.ChatClientRequestSpec {

        private final ObservationRegistry observationRegistry;

        private final ChatClientObservationConvention observationConvention;

        private final ChatModel chatModel;

        private final List<Media> media = new ArrayList<>();

        private final List<String> toolNames = new ArrayList<>();

        private final List<ToolCallback> toolCallbacks = new ArrayList<>();

        private final List<Message> messages = new ArrayList<>();

        private final Map<String, Object> userParams = new HashMap<>();

        private final Map<String, Object> systemParams = new HashMap<>();

        private final List<Advisor> advisors = new ArrayList<>();

        private final Map<String, Object> advisorParams = new HashMap<>();

        private final Map<String, Object> toolContext = new HashMap<>();

        private TemplateRenderer templateRenderer;

        @Nullable
        private String userText;

        @Nullable
        private String systemText;

        @Nullable
        private ChatOptions chatOptions;

        /** Copy constructor used by {@code prompt(...)} to start from the client's defaults. */
        DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
            this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks,
                    ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
                    ccr.observationRegistry, ccr.observationConvention, ccr.toolContext, ccr.templateRenderer);
        }

        /**
         * Creates a spec from explicit values.
         *
         * <p>{@code chatOptions} is copied when given; otherwise the model's default options are copied (or
         * null if the model has none). A null {@code observationConvention} or {@code templateRenderer} falls
         * back to the class-level defaults.
         */
        public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
                Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
                List<ToolCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media,
                @Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams,
                ObservationRegistry observationRegistry,
                @Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
                @Nullable TemplateRenderer templateRenderer) {
            Assert.notNull(chatModel, "chatModel cannot be null");
            Assert.notNull(userParams, "userParams cannot be null");
            Assert.notNull(systemParams, "systemParams cannot be null");
            Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
            Assert.notNull(messages, "messages cannot be null");
            Assert.notNull(toolNames, "toolNames cannot be null");
            Assert.notNull(media, "media cannot be null");
            Assert.notNull(advisors, "advisors cannot be null");
            Assert.notNull(advisorParams, "advisorParams cannot be null");
            Assert.notNull(observationRegistry, "observationRegistry cannot be null");
            Assert.notNull(toolContext, "toolContext cannot be null");

            this.chatModel = chatModel;
            this.chatOptions = chatOptions != null ? chatOptions.copy()
                    : (chatModel.getDefaultOptions() != null ? chatModel.getDefaultOptions().copy() : null);

            this.userText = userText;
            this.userParams.putAll(userParams);
            this.systemText = systemText;
            this.systemParams.putAll(systemParams);

            this.toolNames.addAll(toolNames);
            this.toolCallbacks.addAll(toolCallbacks);
            this.messages.addAll(messages);
            this.media.addAll(media);
            this.advisors.addAll(advisors);
            this.advisorParams.putAll(advisorParams);
            this.observationRegistry = observationRegistry;
            this.observationConvention = observationConvention != null ? observationConvention
                    : DefaultChatClient.DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
            this.toolContext.putAll(toolContext);
            this.templateRenderer = templateRenderer != null ? templateRenderer
                    : DefaultChatClient.DEFAULT_TEMPLATE_RENDERER;
        }

        @Nullable
        public String getUserText() {
            return this.userText;
        }

        public Map<String, Object> getUserParams() {
            return this.userParams;
        }

        @Nullable
        public String getSystemText() {
            return this.systemText;
        }

        public Map<String, Object> getSystemParams() {
            return this.systemParams;
        }

        @Nullable
        public ChatOptions getChatOptions() {
            return this.chatOptions;
        }

        public List<Advisor> getAdvisors() {
            return this.advisors;
        }

        public Map<String, Object> getAdvisorParams() {
            return this.advisorParams;
        }

        public List<Message> getMessages() {
            return this.messages;
        }

        public List<Media> getMedia() {
            return this.media;
        }

        public List<String> getToolNames() {
            return this.toolNames;
        }

        public List<ToolCallback> getToolCallbacks() {
            return this.toolCallbacks;
        }

        public Map<String, Object> getToolContext() {
            return this.toolContext;
        }

        public TemplateRenderer getTemplateRenderer() {
            return this.templateRenderer;
        }

        /**
         * Returns a builder whose defaults mirror this spec. The builder gets the renderer, tools, tool
         * context and messages; user and system text (with params) only when the text is non-blank; and
         * the options only when non-null.
         */
        public ChatClient.Builder mutate() {
            DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient
                .builder(this.chatModel, this.observationRegistry, this.observationConvention)
                .defaultTemplateRenderer(this.templateRenderer)
                .defaultToolCallbacks(this.toolCallbacks)
                .defaultToolContext(this.toolContext)
                .defaultToolNames(StringUtils.toStringArray(this.toolNames));

            if (StringUtils.hasText(this.userText)) {
                builder.defaultUser(
                        u -> u.text(this.userText).params(this.userParams).media(this.media.toArray(new Media[0])));
            }
            if (StringUtils.hasText(this.systemText)) {
                builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams));
            }
            if (this.chatOptions != null) {
                builder.defaultOptions(this.chatOptions);
            }
            builder.addMessages(this.messages);
            return builder;
        }

        public ChatClient.ChatClientRequestSpec advisors(Consumer<ChatClient.AdvisorSpec> consumer) {
            Assert.notNull(consumer, "consumer cannot be null");
            DefaultAdvisorSpec advisorSpec = new DefaultAdvisorSpec();
            consumer.accept(advisorSpec);
            this.advisorParams.putAll(advisorSpec.getParams());
            this.advisors.addAll(advisorSpec.getAdvisors());
            return this;
        }

        public ChatClient.ChatClientRequestSpec advisors(Advisor... advisors) {
            Assert.notNull(advisors, "advisors cannot be null");
            Assert.noNullElements(advisors, "advisors cannot contain null elements");
            this.advisors.addAll(Arrays.asList(advisors));
            return this;
        }

        public ChatClient.ChatClientRequestSpec advisors(List<Advisor> advisors) {
            Assert.notNull(advisors, "advisors cannot be null");
            Assert.noNullElements(advisors, "advisors cannot contain null elements");
            this.advisors.addAll(advisors);
            return this;
        }

        public ChatClient.ChatClientRequestSpec messages(Message... messages) {
            Assert.notNull(messages, "messages cannot be null");
            Assert.noNullElements(messages, "messages cannot contain null elements");
            this.messages.addAll(List.of(messages));
            return this;
        }

        public ChatClient.ChatClientRequestSpec messages(List<Message> messages) {
            Assert.notNull(messages, "messages cannot be null");
            Assert.noNullElements(messages, "messages cannot contain null elements");
            this.messages.addAll(messages);
            return this;
        }

        /** Replaces the options; note the given instance is stored as-is (not copied). */
        public <T extends ChatOptions> ChatClient.ChatClientRequestSpec options(T options) {
            Assert.notNull(options, "options cannot be null");
            this.chatOptions = options;
            return this;
        }

        public ChatClient.ChatClientRequestSpec toolNames(String... toolNames) {
            Assert.notNull(toolNames, "toolNames cannot be null");
            Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
            this.toolNames.addAll(List.of(toolNames));
            return this;
        }

        public ChatClient.ChatClientRequestSpec toolCallbacks(ToolCallback... toolCallbacks) {
            Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
            Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
            this.toolCallbacks.addAll(List.of(toolCallbacks));
            return this;
        }

        public ChatClient.ChatClientRequestSpec toolCallbacks(List<ToolCallback> toolCallbacks) {
            Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
            Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
            this.toolCallbacks.addAll(toolCallbacks);
            return this;
        }

        /** Converts the given tool objects to callbacks via {@code ToolCallbacks.from} and adds them. */
        public ChatClient.ChatClientRequestSpec tools(Object... toolObjects) {
            Assert.notNull(toolObjects, "toolObjects cannot be null");
            Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
            this.toolCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
            return this;
        }

        public ChatClient.ChatClientRequestSpec toolCallbacks(ToolCallbackProvider... toolCallbackProviders) {
            Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null");
            Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements");
            for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
                this.toolCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks()));
            }
            return this;
        }

        public ChatClient.ChatClientRequestSpec toolContext(Map<String, Object> toolContext) {
            Assert.notNull(toolContext, "toolContext cannot be null");
            Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements");
            Assert.noNullElements(toolContext.values(), "toolContext values cannot contain null elements");
            this.toolContext.putAll(toolContext);
            return this;
        }

        public ChatClient.ChatClientRequestSpec system(String text) {
            Assert.hasText(text, "text cannot be null or empty");
            this.systemText = text;
            return this;
        }

        /** Reads the system text from a resource; I/O failures are rethrown as RuntimeException. */
        public ChatClient.ChatClientRequestSpec system(Resource text, Charset charset) {
            Assert.notNull(text, "text cannot be null");
            Assert.notNull(charset, "charset cannot be null");
            try {
                this.systemText = text.getContentAsString(charset);
                return this;
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        public ChatClient.ChatClientRequestSpec system(Resource text) {
            Assert.notNull(text, "text cannot be null");
            return this.system(text, Charset.defaultCharset());
        }

        /** Applies a system spec: replaces the text only if the spec set non-blank text; merges params. */
        public ChatClient.ChatClientRequestSpec system(Consumer<ChatClient.PromptSystemSpec> consumer) {
            Assert.notNull(consumer, "consumer cannot be null");
            DefaultPromptSystemSpec systemSpec = new DefaultPromptSystemSpec();
            consumer.accept(systemSpec);
            this.systemText = StringUtils.hasText(systemSpec.text()) ? systemSpec.text() : this.systemText;
            this.systemParams.putAll(systemSpec.params());
            return this;
        }

        public ChatClient.ChatClientRequestSpec user(String text) {
            Assert.hasText(text, "text cannot be null or empty");
            this.userText = text;
            return this;
        }

        /** Reads the user text from a resource; I/O failures are rethrown as RuntimeException. */
        public ChatClient.ChatClientRequestSpec user(Resource text, Charset charset) {
            Assert.notNull(text, "text cannot be null");
            Assert.notNull(charset, "charset cannot be null");
            try {
                this.userText = text.getContentAsString(charset);
                return this;
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        public ChatClient.ChatClientRequestSpec user(Resource text) {
            Assert.notNull(text, "text cannot be null");
            return this.user(text, Charset.defaultCharset());
        }

        /**
         * Applies a user spec: replaces the text only if the spec set non-blank text; merges params and
         * appends media.
         */
        public ChatClient.ChatClientRequestSpec user(Consumer<ChatClient.PromptUserSpec> consumer) {
            Assert.notNull(consumer, "consumer cannot be null");
            DefaultPromptUserSpec us = new DefaultPromptUserSpec();
            consumer.accept(us);
            this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText;
            this.userParams.putAll(us.params());
            this.media.addAll(us.media());
            return this;
        }

        public ChatClient.ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer) {
            Assert.notNull(templateRenderer, "templateRenderer cannot be null");
            this.templateRenderer = templateRenderer;
            return this;
        }

        /** Builds the advisor chain and a {@link ChatClientRequest} from this spec for a synchronous call. */
        public ChatClient.CallResponseSpec call() {
            BaseAdvisorChain advisorChain = this.buildAdvisorChain();
            return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
                    this.observationRegistry, this.observationConvention);
        }

        /** Builds the advisor chain and a {@link ChatClientRequest} from this spec for a streaming call. */
        public ChatClient.StreamResponseSpec stream() {
            BaseAdvisorChain advisorChain = this.buildAdvisorChain();
            return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
                    this.observationRegistry, this.observationConvention);
        }

        /**
         * Appends the model advisors (call and stream) to this spec's advisor list and builds the chain from
         * that list. NOTE: this mutates {@code this.advisors}, so invoking call()/stream() more than once on
         * the same spec appends the model advisors again.
         */
        private BaseAdvisorChain buildAdvisorChain() {
            this.advisors.add(ChatModelCallAdvisor.builder().chatModel(this.chatModel).build());
            this.advisors.add(ChatModelStreamAdvisor.builder().chatModel(this.chatModel).build());
            return DefaultAroundAdvisorChain.builder(this.observationRegistry)
                .pushAll(this.advisors)
                .templateRenderer(this.templateRenderer)
                .build();
        }

    }

}
DefaultChatClientUtils
类作用:用来将 DefaultChatClient.DefaultChatClientRequestSpec 转换为 ChatClientRequest
- 处理系统提示
- 处理用户提示
- 处理工具调用选项
package org.springframework.ai.chat.client;
import java.util.ArrayList;import java.util.HashSet;import java.util.List;import java.util.Map;import java.util.Set;import java.util.concurrent.ConcurrentHashMap;import org.springframework.ai.chat.messages.Message;import org.springframework.ai.chat.messages.SystemMessage;import org.springframework.ai.chat.messages.UserMessage;import org.springframework.ai.chat.prompt.ChatOptions;import org.springframework.ai.chat.prompt.Prompt;import org.springframework.ai.chat.prompt.PromptTemplate;import org.springframework.ai.model.tool.ToolCallingChatOptions;import org.springframework.ai.tool.ToolCallback;import org.springframework.util.Assert;import org.springframework.util.CollectionUtils;import org.springframework.util.StringUtils;
/**
 * Converts a {@link DefaultChatClient.DefaultChatClientRequestSpec} into a {@link ChatClientRequest}.
 *
 * <p>It renders the system and user templates, assembles the message list, merges the spec's tool
 * settings into the options (when they are {@link ToolCallingChatOptions}), and copies the advisor
 * params into the request context.
 */
final class DefaultChatClientUtils {

    private DefaultChatClientUtils() {
    }

    static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClientRequestSpec inputRequest) {
        Assert.notNull(inputRequest, "inputRequest cannot be null");

        // Message order: rendered system message, then explicit messages, then the rendered user message.
        List<Message> processedMessages = new ArrayList<>();

        String processedSystemText = inputRequest.getSystemText();
        if (StringUtils.hasText(processedSystemText)) {
            // Templates are only rendered when params are present; otherwise the text is used verbatim.
            if (!CollectionUtils.isEmpty(inputRequest.getSystemParams())) {
                processedSystemText = PromptTemplate.builder()
                    .template(processedSystemText)
                    .variables(inputRequest.getSystemParams())
                    .renderer(inputRequest.getTemplateRenderer())
                    .build()
                    .render();
            }
            processedMessages.add(new SystemMessage(processedSystemText));
        }

        if (!CollectionUtils.isEmpty(inputRequest.getMessages())) {
            processedMessages.addAll(inputRequest.getMessages());
        }

        String processedUserText = inputRequest.getUserText();
        if (StringUtils.hasText(processedUserText)) {
            if (!CollectionUtils.isEmpty(inputRequest.getUserParams())) {
                processedUserText = PromptTemplate.builder()
                    .template(processedUserText)
                    .variables(inputRequest.getUserParams())
                    .renderer(inputRequest.getTemplateRenderer())
                    .build()
                    .render();
            }
            processedMessages.add(UserMessage.builder().text(processedUserText).media(inputRequest.getMedia()).build());
        }

        // Tool names/callbacks/context can only be merged into tool-calling options. Note that the
        // options instance returned by the spec is modified in place.
        ChatOptions processedChatOptions = inputRequest.getChatOptions();
        if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
            if (!inputRequest.getToolNames().isEmpty()) {
                Set<String> toolNames = ToolCallingChatOptions.mergeToolNames(
                        new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames());
                toolCallingChatOptions.setToolNames(toolNames);
            }
            if (!inputRequest.getToolCallbacks().isEmpty()) {
                List<ToolCallback> toolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(
                        inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks());
                ToolCallingChatOptions.validateToolCallbacks(toolCallbacks);
                toolCallingChatOptions.setToolCallbacks(toolCallbacks);
            }
            if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) {
                Map<String, Object> toolContext = ToolCallingChatOptions.mergeToolContext(
                        inputRequest.getToolContext(), toolCallingChatOptions.getToolContext());
                toolCallingChatOptions.setToolContext(toolContext);
            }
        }

        return ChatClientRequest.builder()
            .prompt(Prompt.builder().messages(processedMessages).chatOptions(processedChatOptions).build())
            .context(new ConcurrentHashMap<>(inputRequest.getAdvisorParams()))
            .build();
    }

}
AdvisorChain
AdvisorChain 负责按顺序调用一系列增强器(Advisor):每个 Advisor 的输入是 ChatClientRequest,输出是 ChatClientResponse。其中必定会用到的是 ChatModelCallAdvisor 或 ChatModelStreamAdvisor,它们在 buildAdvisorChain 中被追加到 advisor 列表,负责最终调用 ChatModel。
- ChatModelCallAdvisor 触发 ChatModel 的 call 方法
- ChatModelStreamAdvisor 触发 ChatModel 的 stream 方法
ChatModel
package org.springframework.ai.chat.model;
import java.util.Arrays;import org.springframework.ai.chat.messages.Message;import org.springframework.ai.chat.messages.UserMessage;import org.springframework.ai.chat.prompt.ChatOptions;import org.springframework.ai.chat.prompt.Prompt;import org.springframework.ai.model.Model;import reactor.core.publisher.Flux;
/**
 * A chat model: maps a {@link Prompt} to a {@link ChatResponse} synchronously, and optionally as a stream.
 */
public interface ChatModel extends Model<Prompt, ChatResponse>, StreamingChatModel {

    /**
     * Convenience overload: sends {@code message} as a single {@link UserMessage}.
     *
     * @return the output text of the response's result, or {@code ""} when the response has no result
     */
    default String call(String message) {
        ChatResponse response = this.call(new Prompt(new UserMessage(message)));
        Generation result = response.getResult();
        if (result == null) {
            return "";
        }
        return result.getOutput().getText();
    }

    /**
     * Convenience overload: sends all {@code messages} as one prompt.
     *
     * @return the output text of the response's result, or {@code ""} when the response has no result
     */
    default String call(Message... messages) {
        ChatResponse response = this.call(new Prompt(Arrays.asList(messages)));
        Generation result = response.getResult();
        if (result == null) {
            return "";
        }
        return result.getOutput().getText();
    }

    /** Executes the prompt and returns the complete response. */
    ChatResponse call(Prompt prompt);

    /** Default options for this model; the base implementation returns a freshly built, empty options object. */
    default ChatOptions getDefaultOptions() {
        return ChatOptions.builder().build();
    }

    /** Streaming is opt-in: implementations that support it must override this method. */
    default Flux<ChatResponse> stream(Prompt prompt) {
        throw new UnsupportedOperationException("streaming is not supported");
    }

}
不同厂商各自实现 ChatModel,但实现逻辑基本以 OpenAI 的官方实现(OpenAiChatModel)为参照
pom 引入对应依赖
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-autoconfigure-model-openai</artifactId></dependency>
OpenAiChatModel
各字段说明
字段名 | 类型 | 描述 |
defaultOptions | OpenAiChatOptions | 请求参数配置,如temperature、最大 token 数等 |
retryTemplate | RetryTemplate | 用于执行重试逻辑,适用于网络不稳定或 API 限流等 |
openAiApi | OpenAiApi | 封装OpenAI官方API的调用接口 |
observationRegistry | ObservationRegistry | 用于注册和记录观测日志,便于监控和分析调用过程 |
toolCallingManager | ToolCallingManager | 工具调用管理器,用于解析并执行工具调用 |
toolExecutionEligibilityPredicate | ToolExecutionEligibilityPredicate | 判断是否需要执行工具调用的断言函数 |
observationConvention | ChatModelObservationConvention | 自定义观测日志格式的约定对象 |
对外暴露的方法
方法名 | 描述 |
call | 发起一次同步请求,返回完整的 ChatResponse,实际调用内部的internalCall方法 |
internalCall | 1. 构建 OpenAI 请求对象 2. 创建观测上下文 3. 执行带观测的模型调用 4. 执行 OpenAI 接口调用 5. 解析模型返回的 choices 6. 将每个 choice 转换为 Generation 对象,构建完整的 Generation 列表 7. 提取限流信息(RateLimit) 8. 计算 token 使用量 9. 构建最终的 ChatResponse 并设置上下文 10. 工具调用处理 |
stream | 发起一次流式请求,返回 Flux<ChatResponse> |
internalStream | 1. 使用 Flux.deferContextual 延迟执行,保持上下文一致性 2. 构建 OpenAI 流式请求对象 3. 发起流式 API 调用,获取 chunk 数据 4. 创建角色映射表,解决 chunk 中 role 缺失问题 5. 创建观测上下文 6. 启动观测操作 7. 将 chunk 转换为 ChatCompletion 标准格式 8. 转换为 ChatResponse 并构建生成内容 9. 处理 usage 字段(仅最终 chunk 包含完整 usage) 10. 工具调用处理 11. 聚合消息流并设置响应 |
getDefaultOptions | 返回当前模型使用的默认请求参数(OpenAiChatOptions) |
setObservationConvention | 设置自定义的观测日志格式化规则 |
mutate | 复制OpenAiChatModel实例 |
package org.springframework.ai.openai;
import io.micrometer.observation.Observation;import io.micrometer.observation.ObservationRegistry;import java.util.ArrayList;import java.util.Base64;import java.util.Collection;import java.util.HashMap;import java.util.List;import java.util.Map;import java.util.Objects;import java.util.concurrent.ConcurrentHashMap;import java.util.stream.Collectors;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.ai.chat.messages.AssistantMessage;import org.springframework.ai.chat.messages.MessageType;import org.springframework.ai.chat.messages.ToolResponseMessage;import org.springframework.ai.chat.messages.UserMessage;import org.springframework.ai.chat.metadata.ChatGenerationMetadata;import org.springframework.ai.chat.metadata.ChatResponseMetadata;import org.springframework.ai.chat.metadata.DefaultUsage;import org.springframework.ai.chat.metadata.EmptyUsage;import org.springframework.ai.chat.metadata.RateLimit;import org.springframework.ai.chat.metadata.Usage;import org.springframework.ai.chat.model.ChatModel;import org.springframework.ai.chat.model.ChatResponse;import org.springframework.ai.chat.model.Generation;import org.springframework.ai.chat.model.MessageAggregator;import org.springframework.ai.chat.observation.ChatModelObservationContext;import org.springframework.ai.chat.observation.ChatModelObservationConvention;import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;import org.springframework.ai.chat.prompt.ChatOptions;import org.springframework.ai.chat.prompt.Prompt;import org.springframework.ai.content.Media;import org.springframework.ai.model.ModelOptionsUtils;import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;import org.springframework.ai.model.tool.ToolCallingChatOptions;import org.springframework.ai.model.tool.ToolCallingManager;import 
org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;import org.springframework.ai.model.tool.ToolExecutionResult;import org.springframework.ai.openai.api.OpenAiApi;import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio.Format;import org.springframework.ai.openai.api.common.OpenAiApiConstants;import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;import org.springframework.ai.retry.RetryUtils;import org.springframework.ai.support.UsageCalculator;import org.springframework.ai.tool.definition.ToolDefinition;import org.springframework.core.io.ByteArrayResource;import org.springframework.core.io.Resource;import org.springframework.http.ResponseEntity;import org.springframework.retry.support.RetryTemplate;import org.springframework.util.Assert;import org.springframework.util.CollectionUtils;import org.springframework.util.MimeType;import org.springframework.util.MimeTypeUtils;import org.springframework.util.MultiValueMap;import org.springframework.util.StringUtils;import reactor.core.publisher.Flux;import reactor.core.publisher.Mono;import reactor.core.scheduler.Schedulers;
/**
 * {@link ChatModel} backed by the OpenAI Chat Completions API, reached through an
 * {@link OpenAiApi} client.
 *
 * <p>Blocking calls go through {@link #call(Prompt)} and are retried with the configured
 * {@link RetryTemplate}. Streaming calls go through {@link #stream(Prompt)}. Both paths run
 * inside a Micrometer {@link Observation}. Both also hand tool calls to the
 * {@link ToolCallingManager} whenever the {@link ToolExecutionEligibilityPredicate} requires it.
 */
public class OpenAiChatModel implements ChatModel {

    private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class);

    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION =
            new DefaultChatModelObservationConvention();

    private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();

    /** Options used for every request; runtime prompt options are merged on top of them. */
    private final OpenAiChatOptions defaultOptions;

    /** Retries the blocking HTTP call; the streaming path does not use it. */
    private final RetryTemplate retryTemplate;

    /** Low-level HTTP client for the OpenAI endpoints. */
    private final OpenAiApi openAiApi;

    private final ObservationRegistry observationRegistry;

    private final ToolCallingManager toolCallingManager;

    /** Decides, per response, whether tool calls must be executed before returning. */
    private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;

    /** Not final so it can be replaced through {@link #setObservationConvention}. */
    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions,
            ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
            ObservationRegistry observationRegistry) {
        this(openAiApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry,
                new DefaultToolExecutionEligibilityPredicate());
    }

    /**
     * Creates the model. Every argument is mandatory.
     *
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions,
            ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
            ObservationRegistry observationRegistry,
            ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
        Assert.notNull(openAiApi, "openAiApi cannot be null");
        Assert.notNull(defaultOptions, "defaultOptions cannot be null");
        Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
        Assert.notNull(retryTemplate, "retryTemplate cannot be null");
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null");
        this.openAiApi = openAiApi;
        this.defaultOptions = defaultOptions;
        this.toolCallingManager = toolCallingManager;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
        this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
    }

    /**
     * Sends the prompt as a single blocking request.
     * Runtime options are merged with the defaults first.
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        Prompt requestPrompt = buildRequestPrompt(prompt);
        return internalCall(requestPrompt, null);
    }

    /**
     * Performs one blocking round-trip, then recurses while the model keeps asking for tools.
     *
     * @param prompt already-merged prompt (see {@link #buildRequestPrompt})
     * @param previousChatResponse response of the previous tool round, used to accumulate
     * usage; {@code null} on the first call
     * @return the final response, or a response carrying tool results when they are returned directly
     */
    public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
        OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);

        ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(OpenAiApiConstants.PROVIDER_NAME)
            .build();

        ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = this.retryTemplate
                    .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt)));

                OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
                if (chatCompletion == null) {
                    logger.warn("No chat completion returned for prompt: {}", prompt);
                    return new ChatResponse(List.of());
                }

                List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
                if (choices == null) {
                    logger.warn("No choices returned for prompt: {}", prompt);
                    return new ChatResponse(List.of());
                }

                String id = chatCompletion.id() != null ? chatCompletion.id() : "";
                List<Generation> generations = choices.stream().map(choice -> {
                    String role = choice.message().role() != null ? choice.message().role().name() : "";
                    return buildGeneration(choice, choiceMetadata(id, role, choice), request);
                }).toList();

                // Rate-limit information is only available from the blocking call's response headers.
                RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
                Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(toUsage(chatCompletion.usage()),
                        previousChatResponse);
                ChatResponse chatResponse = new ChatResponse(generations,
                        from(chatCompletion, rateLimit, accumulatedUsage));
                observationContext.setResponse(chatResponse);
                return chatResponse;
            });

        if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
            ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
            if (toolExecutionResult.returnDirect()) {
                // Hand the tool output straight back to the caller without another model round-trip.
                return ChatResponse.builder()
                    .from(response)
                    .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                    .build();
            }
            // Feed the tool results back to the model; `response` carries usage into the next round.
            return internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
        }
        return response;
    }

    /**
     * Sends the prompt as a streaming request.
     * Runtime options are merged with the defaults first.
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        Prompt requestPrompt = buildRequestPrompt(prompt);
        return internalStream(requestPrompt, null);
    }

    /**
     * Streams one model round. When a chunk requires tool execution, it recurses into a new stream.
     *
     * @param prompt already-merged prompt (see {@link #buildRequestPrompt})
     * @param previousChatResponse response of the previous tool round, used to accumulate
     * usage; {@code null} on the first call
     * @return aggregated stream of chat responses
     * @throws IllegalArgumentException (on subscription) if audio output or audio parameters are requested
     */
    public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
        return Flux.deferContextual(contextView -> {
            OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);

            // The elements may be OutputModality enum constants rather than strings, so compare by
            // text (case-insensitively). Otherwise equals("audio") never matches and the guard is dead.
            if (request.outputModalities() != null && request.outputModalities()
                .stream()
                .anyMatch(m -> "audio".equalsIgnoreCase(String.valueOf(m)))) {
                logger.warn("Audio output is not supported for streaming requests.");
                throw new IllegalArgumentException("Audio output is not supported for streaming requests.");
            }
            if (request.audioParameters() != null) {
                logger.warn("Audio parameters are not supported for streaming requests.");
                throw new IllegalArgumentException("Audio parameters are not supported for streaming requests.");
            }

            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
                    getAdditionalHttpHeaders(prompt));

            // Later chunks of the same completion may omit the role, so remember it per completion id.
            ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

            ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
                .prompt(prompt)
                .provider(OpenAiApiConstants.PROVIDER_NAME)
                .build();

            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry);
            observation.parentObservation(contextView.getOrDefault("micrometer.observation", null)).start();

            Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
                .switchMap(chatCompletion -> Mono.just(chatCompletion).map(completion -> {
                    try {
                        String id = completion.id() == null ? "NOID" : completion.id();
                        List<Generation> generations = completion.choices().stream().map(choice -> {
                            if (choice.message().role() != null) {
                                roleMap.putIfAbsent(id, choice.message().role().name());
                            }
                            Map<String, Object> metadata = choiceMetadata(id, roleMap.getOrDefault(id, ""), choice);
                            return buildGeneration(choice, metadata, request);
                        }).toList();
                        Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(toUsage(completion.usage()),
                                previousChatResponse);
                        return new ChatResponse(generations, from(completion, null, accumulatedUsage));
                    }
                    catch (Exception e) {
                        // Best effort: one malformed chunk must not terminate the whole stream.
                        logger.error("Error processing chat completion", e);
                        return new ChatResponse(List.of());
                    }
                }))
                // Sliding window of two: when usage reporting is enabled, fold the following
                // response's usage into the current one.
                .buffer(2, 1)
                .map(bufferList -> {
                    ChatResponse firstResponse = bufferList.get(0);
                    if (request.streamOptions() != null
                            && Boolean.TRUE.equals(request.streamOptions().includeUsage())
                            && bufferList.size() == 2) {
                        ChatResponse secondResponse = bufferList.get(1);
                        if (secondResponse != null && secondResponse.getMetadata() != null) {
                            Usage usage = secondResponse.getMetadata().getUsage();
                            if (!UsageCalculator.isEmpty(usage)) {
                                return new ChatResponse(firstResponse.getResults(),
                                        from(firstResponse.getMetadata(), usage));
                            }
                        }
                    }
                    return firstResponse;
                });

            Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
                if (!this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
                    return Flux.just(response);
                }
                // executeToolCalls returns synchronously, so run it on boundedElastic rather than
                // the thread delivering the stream.
                return Flux.defer(() -> {
                    ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt,
                            response);
                    if (toolExecutionResult.returnDirect()) {
                        return Flux.just(ChatResponse.builder()
                            .from(response)
                            .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                            .build());
                    }
                    return internalStream(
                            new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
                }).subscribeOn(Schedulers.boundedElastic());
            })
                .doOnError(observation::error)
                .doFinally(signal -> observation.stop())
                .contextWrite(ctx -> ctx.put("micrometer.observation", observation));

            return new MessageAggregator().aggregate(flux, observationContext::setResponse);
        });
    }

    /**
     * Builds the per-request HTTP headers.
     * OpenAiChatOptions runtime headers override the default-option headers.
     */
    private MultiValueMap<String, String> getAdditionalHttpHeaders(Prompt prompt) {
        Map<String, String> headers = new HashMap<>(this.defaultOptions.getHttpHeaders());
        if (prompt.getOptions() instanceof OpenAiChatOptions chatOptions) {
            headers.putAll(chatOptions.getHttpHeaders());
        }
        return CollectionUtils.toMultiValueMap(
                headers.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> List.of(e.getValue()))));
    }

    /**
     * Builds the per-choice metadata map attached to each {@link AssistantMessage}.
     * Missing values are replaced by empty strings or lists.
     */
    private static Map<String, Object> choiceMetadata(String id, String role, OpenAiApi.ChatCompletion.Choice choice) {
        return Map.of(
                "id", id,
                "role", role,
                "index", choice.index(),
                "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
                "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
                "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of());
    }

    /**
     * Converts one API choice into a {@link Generation}, including tool calls, any audio output
     * (decoded into a {@link Media}), and logprobs when the request asked for them.
     */
    private Generation buildGeneration(OpenAiApi.ChatCompletion.Choice choice, Map<String, Object> metadata,
            OpenAiApi.ChatCompletionRequest request) {
        List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
                : choice.message()
                    .toolCalls()
                    .stream()
                    .map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function",
                            toolCall.function().name(), toolCall.function().arguments()))
                    .toList();

        String finishReason = choice.finishReason() != null ? choice.finishReason().name() : "";
        ChatGenerationMetadata.Builder generationMetadataBuilder = ChatGenerationMetadata.builder()
            .finishReason(finishReason);

        List<Media> media = new ArrayList<>();
        String textContent = choice.message().content();
        OpenAiApi.ChatCompletionMessage.AudioOutput audioOutput = choice.message().audioOutput();
        if (audioOutput != null) {
            String mimeType = String.format("audio/%s", request.audioParameters().format().name().toLowerCase());
            byte[] audioData = Base64.getDecoder().decode(audioOutput.data());
            Resource resource = new ByteArrayResource(audioData);
            // Build the Media once (the previous version built and discarded an identical instance).
            media.add(Media.builder()
                .mimeType(MimeTypeUtils.parseMimeType(mimeType))
                .data(resource)
                .id(audioOutput.id())
                .build());
            // Audio-only replies have no text content; fall back to the transcript.
            if (!StringUtils.hasText(textContent)) {
                textContent = audioOutput.transcript();
            }
            generationMetadataBuilder.metadata("audioId", audioOutput.id());
            generationMetadataBuilder.metadata("audioExpiresAt", audioOutput.expiresAt());
        }

        if (Boolean.TRUE.equals(request.logprobs())) {
            generationMetadataBuilder.metadata("logprobs", choice.logprobs());
        }

        AssistantMessage assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media);
        return new Generation(assistantMessage, generationMetadataBuilder.build());
    }

    /**
     * Builds response metadata from an API completion.
     * Missing id/model/created/fingerprint values become "" or 0.
     */
    private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit, Usage usage) {
        Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
        ChatResponseMetadata.Builder builder = ChatResponseMetadata.builder()
            .id(result.id() != null ? result.id() : "")
            .usage(usage)
            .model(result.model() != null ? result.model() : "")
            .keyValue("created", result.created() != null ? result.created() : 0L)
            .keyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : "");
        if (rateLimit != null) {
            builder.rateLimit(rateLimit);
        }
        return builder.build();
    }

    /**
     * Copies id, model and rate limit from existing metadata, replacing its usage.
     * Other key-values (e.g. "created") are not carried over.
     */
    private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usage usage) {
        Assert.notNull(chatResponseMetadata, "OpenAI ChatResponseMetadata must not be null");
        ChatResponseMetadata.Builder builder = ChatResponseMetadata.builder()
            .id(chatResponseMetadata.getId() != null ? chatResponseMetadata.getId() : "")
            .usage(usage)
            .model(chatResponseMetadata.getModel() != null ? chatResponseMetadata.getModel() : "");
        if (chatResponseMetadata.getRateLimit() != null) {
            builder.rateLimit(chatResponseMetadata.getRateLimit());
        }
        return builder.build();
    }

    /**
     * Re-shapes a streaming chunk into a {@link OpenAiApi.ChatCompletion}.
     * Each chunk choice's delta is used as the message, so the blocking-path code can be reused.
     */
    private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionChunk chunk) {
        List<OpenAiApi.ChatCompletion.Choice> choices = chunk.choices()
            .stream()
            .map(chunkChoice -> new OpenAiApi.ChatCompletion.Choice(chunkChoice.finishReason(), chunkChoice.index(),
                    chunkChoice.delta(), chunkChoice.logprobs()))
            .toList();
        return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(),
                chunk.systemFingerprint(), "chat.completion", chunk.usage());
    }

    private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) {
        return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
    }

    /** Maps API usage to {@link Usage}; an absent usage block becomes {@link EmptyUsage}. */
    private Usage toUsage(OpenAiApi.Usage usage) {
        return usage != null ? getDefaultUsage(usage) : new EmptyUsage();
    }

    /**
     * Merges the prompt's runtime options over {@link #defaultOptions}.
     * Headers, tool names, callbacks and tool context are merged explicitly.
     *
     * @return a new prompt with the same instructions and the merged {@link OpenAiChatOptions}
     */
    Prompt buildRequestPrompt(Prompt prompt) {
        OpenAiChatOptions runtimeOptions = null;
        if (prompt.getOptions() != null) {
            if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
                runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
                        OpenAiChatOptions.class);
            }
            else {
                runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
                        OpenAiChatOptions.class);
            }
        }

        OpenAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
                OpenAiChatOptions.class);

        if (runtimeOptions != null) {
            if (runtimeOptions.getTopK() != null) {
                logger.warn("The topK option is not supported by OpenAI chat models. Ignoring.");
            }
            requestOptions.setHttpHeaders(
                    mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
            requestOptions.setInternalToolExecutionEnabled(ModelOptionsUtils.mergeOption(
                    runtimeOptions.getInternalToolExecutionEnabled(),
                    this.defaultOptions.getInternalToolExecutionEnabled()));
            requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
                    this.defaultOptions.getToolNames()));
            requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(
                    runtimeOptions.getToolCallbacks(), this.defaultOptions.getToolCallbacks()));
            requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
                    this.defaultOptions.getToolContext()));
        }
        else {
            requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
            requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
            requestOptions.setToolNames(this.defaultOptions.getToolNames());
            requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
            requestOptions.setToolContext(this.defaultOptions.getToolContext());
        }

        ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
        return new Prompt(prompt.getInstructions(), requestOptions);
    }

    /** Returns default headers overlaid with runtime headers (runtime wins on conflicts). */
    private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHeaders,
            Map<String, String> defaultHttpHeaders) {
        Map<String, String> mergedHttpHeaders = new HashMap<>(defaultHttpHeaders);
        mergedHttpHeaders.putAll(runtimeHttpHeaders);
        return mergedHttpHeaders;
    }

    /**
     * Translates a merged prompt into an OpenAI request.
     *
     * @param prompt prompt whose options are expected to be {@link OpenAiChatOptions} (as produced by
     * {@link #buildRequestPrompt}); any other options type fails the cast
     * @param stream whether the request is a streaming one; stream options are dropped otherwise
     * @throws IllegalArgumentException for unsupported message types or tool responses without an id
     */
    OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
        List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
            .stream()
            .map(message -> {
                if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
                    // User messages with media are sent as a multi-part content list: text first, then media.
                    Object content = message.getText();
                    if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) {
                        List<OpenAiApi.ChatCompletionMessage.MediaContent> contentList = new ArrayList<>(
                                List.of(new OpenAiApi.ChatCompletionMessage.MediaContent(message.getText())));
                        contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
                        content = contentList;
                    }
                    return List.of(new OpenAiApi.ChatCompletionMessage(content,
                            Role.valueOf(message.getMessageType().name())));
                }
                else if (message.getMessageType() == MessageType.ASSISTANT) {
                    AssistantMessage assistantMessage = (AssistantMessage) message;
                    List<OpenAiApi.ChatCompletionMessage.ToolCall> toolCalls = null;
                    if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                        toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
                            OpenAiApi.ChatCompletionMessage.ChatCompletionFunction function =
                                    new OpenAiApi.ChatCompletionMessage.ChatCompletionFunction(toolCall.name(),
                                            toolCall.arguments());
                            return new OpenAiApi.ChatCompletionMessage.ToolCall(toolCall.id(), toolCall.type(),
                                    function);
                        }).toList();
                    }
                    // A previous audio reply is referenced by id only.
                    OpenAiApi.ChatCompletionMessage.AudioOutput audioOutput = null;
                    if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) {
                        Assert.isTrue(assistantMessage.getMedia().size() == 1,
                                "Only one media content is supported for assistant messages");
                        audioOutput = new OpenAiApi.ChatCompletionMessage.AudioOutput(
                                assistantMessage.getMedia().get(0).getId(), null, null, null);
                    }
                    return List.of(new OpenAiApi.ChatCompletionMessage(assistantMessage.getText(), Role.ASSISTANT,
                            null, null, toolCalls, null, audioOutput, null));
                }
                else if (message.getMessageType() == MessageType.TOOL) {
                    // One tool-response message expands into one API message per tool response.
                    ToolResponseMessage toolMessage = (ToolResponseMessage) message;
                    toolMessage.getResponses()
                        .forEach(response -> Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"));
                    return toolMessage.getResponses()
                        .stream()
                        .map(tr -> new OpenAiApi.ChatCompletionMessage(tr.responseData(), Role.TOOL, tr.name(),
                                tr.id(), null, null, null, null))
                        .toList();
                }
                else {
                    throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
                }
            })
            .flatMap(Collection::stream)
            .toList();

        OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);

        OpenAiChatOptions requestOptions = (OpenAiChatOptions) prompt.getOptions();
        request = ModelOptionsUtils.merge(requestOptions, request, OpenAiApi.ChatCompletionRequest.class);

        List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
        if (!CollectionUtils.isEmpty(toolDefinitions)) {
            request = ModelOptionsUtils.merge(
                    OpenAiChatOptions.builder().tools(getFunctionTools(toolDefinitions)).build(), request,
                    OpenAiApi.ChatCompletionRequest.class);
        }

        // Stream options are only meaningful for streaming requests.
        if (request.streamOptions() != null && !stream) {
            logger.warn("Removing streamOptions from the request as it is not a streaming request!");
            request = request.streamOptions(null);
        }

        return request;
    }

    /**
     * Maps user media to API content.
     * audio/mp3 and audio/wav become input audio; anything else is sent as an image URL.
     */
    private OpenAiApi.ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
        MimeType mimeType = media.getMimeType();
        if (MimeTypeUtils.parseMimeType("audio/mp3").equals(mimeType)) {
            return new OpenAiApi.ChatCompletionMessage.MediaContent(
                    new OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio(fromAudioData(media.getData()),
                            Format.MP3));
        }
        if (MimeTypeUtils.parseMimeType("audio/wav").equals(mimeType)) {
            return new OpenAiApi.ChatCompletionMessage.MediaContent(
                    new OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio(fromAudioData(media.getData()),
                            Format.WAV));
        }
        return new OpenAiApi.ChatCompletionMessage.MediaContent(new OpenAiApi.ChatCompletionMessage.MediaContent.ImageUrl(
                fromMediaData(media.getMimeType(), media.getData())));
    }

    /** Base64-encodes raw audio bytes; any other data type is rejected. */
    private String fromAudioData(Object audioData) {
        if (audioData instanceof byte[] bytes) {
            return Base64.getEncoder().encodeToString(bytes);
        }
        throw new IllegalArgumentException("Unsupported audio data type: " + audioData.getClass().getSimpleName());
    }

    /** Bytes become a base64 data URL; strings (e.g. URLs) are passed through unchanged. */
    private String fromMediaData(MimeType mimeType, Object mediaContentData) {
        if (mediaContentData instanceof byte[] bytes) {
            return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
        }
        if (mediaContentData instanceof String text) {
            return text;
        }
        throw new IllegalArgumentException(
                "Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
    }

    /** Wraps each tool definition as an OpenAI function tool. */
    private List<OpenAiApi.FunctionTool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
        return toolDefinitions.stream().map(toolDefinition -> {
            OpenAiApi.FunctionTool.Function function = new OpenAiApi.FunctionTool.Function(
                    toolDefinition.description(), toolDefinition.name(), toolDefinition.inputSchema());
            return new OpenAiApi.FunctionTool(function);
        }).toList();
    }

    /** Returns options built from the defaults via {@link OpenAiChatOptions#fromOptions}. */
    @Override
    public ChatOptions getDefaultOptions() {
        return OpenAiChatOptions.fromOptions(this.defaultOptions);
    }

    @Override
    public String toString() {
        return "OpenAiChatModel [defaultOptions=" + this.defaultOptions + "]";
    }

    /**
     * Replaces the convention used to name and tag observations.
     *
     * @throws IllegalArgumentException if {@code observationConvention} is {@code null}
     */
    public void setObservationConvention(ChatModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }

    public static Builder builder() {
        return new Builder();
    }

    /** Returns a builder pre-populated with this instance's configuration. */
    public Builder mutate() {
        return new Builder(this);
    }

    /** Creates a new instance with the same configuration. */
    @Override
    public OpenAiChatModel clone() {
        return mutate().build();
    }

    /** Builder for {@link OpenAiChatModel}. {@link #openAiApi(OpenAiApi)} must be set. */
    public static final class Builder {

        private OpenAiApi openAiApi;

        private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()
            .model(OpenAiApi.DEFAULT_CHAT_MODEL)
            .temperature(0.7)
            .build();

        /** Left null unless configured; {@link #build()} then falls back to a shared default. */
        private ToolCallingManager toolCallingManager;

        private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate =
                new DefaultToolExecutionEligibilityPredicate();

        private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;

        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

        /** Copies every setting from an existing model. */
        public Builder(OpenAiChatModel model) {
            this.openAiApi = model.openAiApi;
            this.defaultOptions = model.defaultOptions;
            this.toolCallingManager = model.toolCallingManager;
            this.toolExecutionEligibilityPredicate = model.toolExecutionEligibilityPredicate;
            this.retryTemplate = model.retryTemplate;
            this.observationRegistry = model.observationRegistry;
        }

        private Builder() {
        }

        public Builder openAiApi(OpenAiApi openAiApi) {
            this.openAiApi = openAiApi;
            return this;
        }

        public Builder defaultOptions(OpenAiChatOptions defaultOptions) {
            this.defaultOptions = defaultOptions;
            return this;
        }

        public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
            this.toolCallingManager = toolCallingManager;
            return this;
        }

        public Builder toolExecutionEligibilityPredicate(
                ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
            this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
            return this;
        }

        public Builder retryTemplate(RetryTemplate retryTemplate) {
            this.retryTemplate = retryTemplate;
            return this;
        }

        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        /**
         * Builds the model, using a shared default {@link ToolCallingManager} when none was set.
         *
         * @throws IllegalArgumentException if a mandatory setting is {@code null}
         */
        public OpenAiChatModel build() {
            ToolCallingManager manager = this.toolCallingManager != null ? this.toolCallingManager
                    : DEFAULT_TOOL_CALLING_MANAGER;
            return new OpenAiChatModel(this.openAiApi, this.defaultOptions, manager, this.retryTemplate,
                    this.observationRegistry, this.toolExecutionEligibilityPredicate);
        }

    }

}
OpenAiApi
各字段说明
字段名 | 类型 | 描述 |
baseUrl | String | OpenAI API 的基础 URL,默认为 "https://api.openai.com" |
apiKey | ApiKey | 认证密钥 |
headers | MultiValueMap | 自定义 HTTP 请求头,例如用户自定义的身份信息等 |
completionsPath | String | Chat Completion 接口路径,默认为 /v1/chat/completions |
embeddingsPath | String | Embedding 接口路径,默认为 /v1/embeddings |
responseErrorHandler | ResponseErrorHandler | 响应错误处理器,默认处理异常逻辑 |
restClient | RestClient | 同步请求客户端,用于非流式请求 |
webClient | WebClient | 异步/响应式请求客户端,用于流式请求 |
chunkMerger | OpenAiStreamFunctionCallingHelper | 流式函数调用合并器,用于合并分散在多个 chunk 中的 function call(工具调用)数据 |
对外暴露的方法
方法名 | 描述 |
chatCompletionEntity | 发送同步请求获取完整的 Chat Completion 响应 |
chatCompletionStream | 发起流式请求,接收分块响应(chunk) |
embeddings | 调用 OpenAI Embedding 接口,生成文本或 token 数组的向量表示 |
内部枚举类说明
枚举类 | 描述 |
ChatModel | 支持的聊天模型 |
ChatCompletionFinishReason | 模型停止生成的原因 |
EmbeddingModel | 支持的Embedding模型 |
OutputModality | 模型输出的模态(例如 audio 音频输出) |
完整代码如下
package org.springframework.ai.openai.api;
import com.fasterxml.jackson.annotation.JsonFormat;import com.fasterxml.jackson.annotation.JsonIgnore;import com.fasterxml.jackson.annotation.JsonIgnoreProperties;import com.fasterxml.jackson.annotation.JsonInclude;import com.fasterxml.jackson.annotation.JsonProperty;import com.fasterxml.jackson.annotation.JsonFormat.Feature;import com.fasterxml.jackson.annotation.JsonInclude.Include;import java.util.List;import java.util.Map;import java.util.concurrent.atomic.AtomicBoolean;import java.util.function.Consumer;import java.util.function.Predicate;import org.springframework.ai.model.ApiKey;import org.springframework.ai.model.ChatModelDescription;import org.springframework.ai.model.ModelOptionsUtils;import org.springframework.ai.model.NoopApiKey;import org.springframework.ai.model.SimpleApiKey;import org.springframework.ai.retry.RetryUtils;import org.springframework.core.ParameterizedTypeReference;import org.springframework.http.HttpHeaders;import org.springframework.http.MediaType;import org.springframework.http.ResponseEntity;import org.springframework.util.Assert;import org.springframework.util.CollectionUtils;import org.springframework.util.LinkedMultiValueMap;import org.springframework.util.MultiValueMap;import org.springframework.web.client.ResponseErrorHandler;import org.springframework.web.client.RestClient;import org.springframework.web.reactive.function.client.WebClient;import reactor.core.publisher.Flux;import reactor.core.publisher.Mono;
public class OpenAiApi { public static final ChatModel DEFAULTCHATMODEL; public static final String DEFAULTEMBEDDINGMODEL; private static final Predicate<String> SSEDONEPREDICATE; private final String baseUrl; private final ApiKey apiKey; private final MultiValueMap<String, String> headers; private final String completionsPath; private final String embeddingsPath; private final ResponseErrorHandler responseErrorHandler; private final RestClient restClient; private final WebClient webClient; private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper();
public Builder mutate() { return new Builder(this); }
public static Builder builder() { return new Builder(); }
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath, String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { this.baseUrl = baseUrl; this.apiKey = apiKey; this.headers = headers; this.completionsPath = completionsPath; this.embeddingsPath = embeddingsPath; this.responseErrorHandler = responseErrorHandler; Assert.hasText(completionsPath, "Completions Path must not be null"); Assert.hasText(embeddingsPath, "Embeddings Path must not be null"); Assert.notNull(headers, "Headers must not be null"); Consumer<HttpHeaders> finalHeaders = (h) -> { if (!(apiKey instanceof NoopApiKey)) { h.setBearerAuth(apiKey.getValue()); }
h.setContentType(MediaType.APPLICATIONJSON); h.addAll(headers); }; this.restClient = restClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(finalHeaders).defaultStatusHandler(responseErrorHandler).build(); this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(finalHeaders).build(); }
/**
 * Concatenates, in list order and with no separator, the {@code text} of every part whose type is
 * {@code "text"}. Parts of any other type are skipped.
 *
 * @param content the multi-modal content parts of a message
 * @return the joined text, or {@code ""} if there are no text parts
 */
public static String getTextContent(List<ChatCompletionMessage.MediaContent> content) {
    StringBuilder joined = new StringBuilder();
    for (ChatCompletionMessage.MediaContent part : content) {
        if ("text".equals(part.type())) {
            joined.append(part.text());
        }
    }
    return joined.toString();
}
/**
 * Sends a blocking chat completion request with no additional per-request headers.
 *
 * @param chatRequest the request body; see the two-argument overload for its preconditions
 * @return the HTTP response entity wrapping the completion
 */
public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest) {
    MultiValueMap<String, String> noExtraHeaders = new LinkedMultiValueMap<>();
    return this.chatCompletionEntity(chatRequest, noExtraHeaders);
}
/**
 * Sends a blocking (non-streaming) chat completion request.
 *
 * @param chatRequest request body; must be non-null and must not have {@code stream == true}
 * @param additionalHttpHeader headers added to this request only; must be non-null
 * @return the HTTP response entity wrapping the {@link ChatCompletion}
 * @throws IllegalArgumentException if a precondition is violated
 */
public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest chatRequest,
        MultiValueMap<String, String> additionalHttpHeader) {
    Assert.notNull(chatRequest, "The request body can not be null.");
    // Treat an unset (null) stream flag as non-streaming instead of failing with an unboxing NPE.
    Assert.isTrue(!Boolean.TRUE.equals(chatRequest.stream()), "Request must set the stream property to false.");
    Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");

    return this.restClient.post()
        .uri(this.completionsPath)
        .headers(h -> h.addAll(additionalHttpHeader))
        .body(chatRequest)
        .retrieve()
        .toEntity(ChatCompletion.class);
}
/**
 * Streams a chat completion with no additional per-request headers.
 *
 * @param chatRequest the request body; see the two-argument overload for its preconditions
 * @return the stream of completion chunks
 */
public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chatRequest) {
    MultiValueMap<String, String> noExtraHeaders = new LinkedMultiValueMap<>();
    return this.chatCompletionStream(chatRequest, noExtraHeaders);
}
/**
 * Streams a chat completion as a sequence of {@link ChatCompletionChunk}s.
 *
 * <p>The response body is read as strings until the {@code "[DONE]"} marker; each remaining string
 * is parsed as a chunk. Consecutive chunks that {@code chunkMerger} identifies as one streamed tool
 * call are grouped into a window and merged, so that call is emitted as a single chunk.
 *
 * @param chatRequest request body; must be non-null and have {@code stream == true}
 * @param additionalHttpHeader headers added to this request only; must be non-null
 * @return a flux of (possibly merged) completion chunks
 * @throws IllegalArgumentException if a precondition is violated
 */
public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chatRequest,
        MultiValueMap<String, String> additionalHttpHeader) {
    Assert.notNull(chatRequest, "The request body can not be null.");
    // Boolean.TRUE.equals turns an unset (null) flag into a clear assertion failure instead of an NPE.
    Assert.isTrue(Boolean.TRUE.equals(chatRequest.stream()), "Request must set the stream property to true.");
    // Fail fast like chatCompletionEntity; otherwise a null map only surfaces as an NPE at subscription.
    Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");

    // Per-call flag: true while incoming chunks belong to an in-progress streamed tool call.
    AtomicBoolean isInsideTool = new AtomicBoolean(false);

    return this.webClient.post()
        .uri(this.completionsPath)
        .headers(h -> h.addAll(additionalHttpHeader))
        .body(Mono.just(chatRequest), ChatCompletionRequest.class)
        .retrieve()
        .bodyToFlux(String.class)
        // Complete the flux once the "[DONE]" marker is received...
        .takeUntil(SSEDONEPREDICATE)
        // ...but never try to parse the marker itself as JSON.
        .filter(SSEDONEPREDICATE.negate())
        .map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class))
        // Record when a streamed tool call starts.
        .map(chunk -> {
            if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) {
                isInsideTool.set(true);
            }
            return chunk;
        })
        // Close the window after every ordinary chunk, and after the finishing chunk of a tool call,
        // so the chunks of one tool call end up in the same window.
        .windowUntil(chunk -> {
            if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) {
                isInsideTool.set(false);
                return true;
            }
            return !isInsideTool.get();
        })
        // Reduce each window into a single merged chunk: Flux<Flux<C>> -> Flux<Mono<C>>.
        .concatMapIterable(window -> {
            Mono<ChatCompletionChunk> monoChunk = window.reduce(
                    new ChatCompletionChunk((String) null, (List) null, (Long) null, (String) null,
                            (String) null, (String) null, (String) null, (Usage) null),
                    (previous, current) -> this.chunkMerger.merge(previous, current));
            return List.of(monoChunk);
        })
        // Flux<Mono<C>> -> Flux<C>
        .flatMap(mono -> mono);
}
/**
 * Calls the embeddings endpoint.
 *
 * <p>The input must be a {@code String} or a non-empty {@code List} of at most 2048 elements whose
 * first element is a {@code String}, an {@code Integer} or a nested {@code List}.
 *
 * @param embeddingRequest the request body; must be non-null with a non-null input
 * @param <T> type of the request input
 * @return the HTTP response entity wrapping the embedding list
 * @throws IllegalArgumentException if the request or its input violates the rules above
 */
public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<T> embeddingRequest) {
    Assert.notNull(embeddingRequest, "The request body can not be null.");
    final Object input = embeddingRequest.input();
    Assert.notNull(input, "The input can not be null.");

    final boolean stringOrList = input instanceof String || input instanceof List;
    Assert.isTrue(stringOrList,
            "The input must be either a String, or a List of Strings or List of List of integers.");

    if (input instanceof List<?> items) {
        Assert.isTrue(!CollectionUtils.isEmpty(items), "The input list can not be empty.");
        Assert.isTrue(items.size() <= 2048, "The list must be 2048 dimensions or less");
        final Object first = items.get(0);
        final boolean supportedElement = first instanceof String || first instanceof Integer
                || first instanceof List;
        Assert.isTrue(supportedElement,
                "The input must be either a String, or a List of Strings or list of list of integers.");
    }

    return this.restClient.post()
        .uri(this.embeddingsPath)
        .body(embeddingRequest)
        .retrieve()
        .toEntity(new ParameterizedTypeReference<EmbeddingList<Embedding>>() {
        });
}
/** @return the base URL both HTTP clients were configured with */
String getBaseUrl() {
    return baseUrl;
}

/** @return the API key supplied at construction */
ApiKey getApiKey() {
    return apiKey;
}

/** @return the default headers supplied at construction */
MultiValueMap<String, String> getHeaders() {
    return headers;
}

/** @return the chat completions endpoint path */
String getCompletionsPath() {
    return completionsPath;
}

/** @return the embeddings endpoint path */
String getEmbeddingsPath() {
    return embeddingsPath;
}

/** @return the error handler installed on the RestClient */
ResponseErrorHandler getResponseErrorHandler() {
    return responseErrorHandler;
}
static {
    // Default model identifiers exposed as public constants.
    DEFAULTCHATMODEL = ChatModel.GPT4O;
    DEFAULTEMBEDDINGMODEL = EmbeddingModel.TEXTEMBEDDINGADA002.getValue();
    // End-of-stream marker checked by chatCompletionStream (takeUntil / filter).
    SSEDONEPREDICATE = line -> "[DONE]".equals(line);
}
public static enum ChatModel implements ChatModelDescription { O4MINI("o4-mini"), O3("o3"), O3MINI("o3-mini"), O1("o1"), O1MINI("o1-mini"), O1PRO("o1-pro"), GPT41("gpt-4.1"), GPT4O("gpt-4o"), CHATGPT4OLATEST("chatgpt-4o-latest"), GPT4OAUDIOPREVIEW("gpt-4o-audio-preview"), GPT41MINI("gpt-4.1-mini"), GPT41NANO("gpt-4.1-nano"), GPT4OMINI("gpt-4o-mini"), GPT4OMINIAUDIOPREVIEW("gpt-4o-mini-audio-preview"), GPT4OREALTIMEPREVIEW("gpt-4o-realtime-preview"), GPT4OMINIREALTIMEPREVIEW("gpt-4o-mini-realtime-preview\n"), GPT4TURBO("gpt-4-turbo"), GPT4("gpt-4"), GPT35TURBO("gpt-3.5-turbo"), GPT35TURBOINSTRUCT("gpt-3.5-turbo-instruct"), GPT4OSEARCHPREVIEW("gpt-4o-search-preview"), GPT4OMINISEARCHPREVIEW("gpt-4o-mini-search-preview");
public final String value;
private ChatModel(String value) { this.value = value; }
public String getValue() { return this.value; }
public String getName() { return this.value; } }
/**
 * Reason a choice stopped generating ({@code finish_reason}). Each constant's
 * {@link JsonProperty} value is the exact string used on the wire.
 */
public enum ChatCompletionFinishReason {

    @JsonProperty("stop")
    STOP,

    @JsonProperty("length")
    LENGTH,

    @JsonProperty("content_filter")
    CONTENTFILTER,

    @JsonProperty("tool_calls")
    TOOLCALLS,

    @JsonProperty("tool_call")
    TOOLCALL
}
/**
 * Embedding model identifiers; {@link #getValue()} returns the model id string.
 */
public enum EmbeddingModel {

    TEXTEMBEDDING3LARGE("text-embedding-3-large"),
    TEXTEMBEDDING3SMALL("text-embedding-3-small"),
    TEXTEMBEDDINGADA002("text-embedding-ada-002");

    /** The model id string. */
    public final String value;

    EmbeddingModel(String modelId) {
        value = modelId;
    }

    /** @return the model id string */
    public String getValue() {
        return value;
    }
}
/**
 * A tool definition sent in a chat completion request's {@code tools} list. The only {@link Type}
 * modelled here is {@link Type#FUNCTION}.
 */
@JsonInclude(Include.NON_NULL)
public static class FunctionTool {

    /** Tool type; initialised to {@link Type#FUNCTION}. */
    @JsonProperty("type")
    private Type type = Type.FUNCTION;

    /** The function definition the model may call. */
    @JsonProperty("function")
    private Function function;

    /** Creates a tool with no function whose type is {@link Type#FUNCTION}. */
    public FunctionTool() {
    }

    /**
     * Creates a tool with an explicit type.
     *
     * @param type tool type, stored as given (even if {@code null})
     * @param function function definition
     */
    public FunctionTool(Type type, Function function) {
        this.type = type;
        this.function = function;
    }

    /**
     * Creates a {@link Type#FUNCTION} tool.
     *
     * @param function function definition
     */
    public FunctionTool(Function function) {
        this(Type.FUNCTION, function);
    }

    public Type getType() {
        return this.type;
    }

    public Function getFunction() {
        return this.function;
    }

    public void setType(Type type) {
        this.type = type;
    }

    public void setFunction(Function function) {
        this.function = function;
    }

    /** Supported tool types. */
    public enum Type {

        @JsonProperty("function")
        FUNCTION
    }

    /** Definition of a callable function: name, description and JSON-schema parameters. */
    @JsonInclude(Include.NON_NULL)
    public static class Function {

        @JsonProperty("description")
        private String description;

        @JsonProperty("name")
        private String name;

        /** Parameter schema as a map (parsed from JSON schema text when built from a string). */
        @JsonProperty("parameters")
        private Map<String, Object> parameters;

        // Package-private visibility is kept as-is for source compatibility.
        @JsonProperty("strict")
        Boolean strict;

        /** Raw schema text; never serialized, but parsed into {@link #parameters} by its setter. */
        @JsonIgnore
        private String jsonSchema;

        /** No-arg constructor; presumably used by Jackson for deserialization. */
        private Function() {
        }

        public Function(String description, String name, Map<String, Object> parameters, Boolean strict) {
            this.description = description;
            this.name = name;
            this.parameters = parameters;
            this.strict = strict;
        }

        /**
         * Creates a function whose parameters are parsed from JSON schema text. Note that
         * {@link #getJsonSchema()} still returns {@code null} afterwards; only {@link #parameters}
         * is populated.
         */
        public Function(String description, String name, String jsonSchema) {
            this(description, name, ModelOptionsUtils.jsonToMap(jsonSchema), null);
        }

        public String getDescription() {
            return this.description;
        }

        public String getName() {
            return this.name;
        }

        public Map<String, Object> getParameters() {
            return this.parameters;
        }

        public void setDescription(String description) {
            this.description = description;
        }

        public void setName(String name) {
            this.name = name;
        }

        public void setParameters(Map<String, Object> parameters) {
            this.parameters = parameters;
        }

        public Boolean getStrict() {
            return this.strict;
        }

        public void setStrict(Boolean strict) {
            this.strict = strict;
        }

        public String getJsonSchema() {
            return this.jsonSchema;
        }

        /** Stores the schema text and, when non-null, replaces {@link #parameters} with its parsed form. */
        public void setJsonSchema(String jsonSchema) {
            this.jsonSchema = jsonSchema;
            if (jsonSchema != null) {
                this.parameters = ModelOptionsUtils.jsonToMap(jsonSchema);
            }
        }
    }
}
/**
 * Output modalities a completion may be asked to produce; used as the element type of the request's
 * {@code modalities} list.
 */
public enum OutputModality {

    /** Serialized as {@code "audio"}. */
    @JsonProperty("audio")
    AUDIO,

    /** Serialized as {@code "text"}. */
    @JsonProperty("text")
    TEXT
}
/**
 * Request body posted by {@code chatCompletionEntity} and {@code chatCompletionStream}.
 *
 * <p>{@code null} components are omitted from the JSON ({@link Include#NON_NULL}). Each
 * {@link JsonProperty} value is the snake_case field name used by the chat completions API; the
 * annotations sit on the record components, so they apply to the constructor parameters, fields
 * and accessors alike.
 */
@JsonInclude(Include.NON_NULL)
public record ChatCompletionRequest(
        @JsonProperty("messages") List<ChatCompletionMessage> messages,
        @JsonProperty("model") String model,
        @JsonProperty("store") Boolean store,
        @JsonProperty("metadata") Map<String, String> metadata,
        @JsonProperty("frequency_penalty") Double frequencyPenalty,
        @JsonProperty("logit_bias") Map<String, Integer> logitBias,
        @JsonProperty("logprobs") Boolean logprobs,
        @JsonProperty("top_logprobs") Integer topLogprobs,
        @JsonProperty("max_tokens") Integer maxTokens,
        @JsonProperty("max_completion_tokens") Integer maxCompletionTokens,
        @JsonProperty("n") Integer n,
        @JsonProperty("modalities") List<OutputModality> outputModalities,
        @JsonProperty("audio") AudioParameters audioParameters,
        @JsonProperty("presence_penalty") Double presencePenalty,
        @JsonProperty("response_format") ResponseFormat responseFormat,
        @JsonProperty("seed") Integer seed,
        @JsonProperty("service_tier") String serviceTier,
        @JsonProperty("stop") List<String> stop,
        @JsonProperty("stream") Boolean stream,
        @JsonProperty("stream_options") StreamOptions streamOptions,
        @JsonProperty("temperature") Double temperature,
        @JsonProperty("top_p") Double topP,
        @JsonProperty("tools") List<FunctionTool> tools,
        @JsonProperty("tool_choice") Object toolChoice,
        @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls,
        @JsonProperty("user") String user,
        @JsonProperty("reasoning_effort") String reasoningEffort,
        @JsonProperty("web_search_options") WebSearchOptions webSearchOptions) {

    /** Non-streaming request with only messages, model and temperature set. */
    public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
        this(messages, model,
                null, null, null, null, null, null, null, null, // store .. maxCompletionTokens
                null, null, null, null, null, null, null, null, // n .. stop
                false, null, temperature,                       // stream, streamOptions, temperature
                null, null, null, null, null, null, null);      // topP .. webSearchOptions
    }

    /** Audio-output request: modalities are set to [AUDIO, TEXT]; temperature is left unset. */
    public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, AudioParameters audio,
            boolean stream) {
        this(messages, model,
                null, null, null, null, null, null, null, null, // store .. maxCompletionTokens
                null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio,
                null, null, null, null, null,                   // presencePenalty .. stop
                stream, null, null,                             // stream, streamOptions, temperature
                null, null, null, null, null, null, null);      // topP .. webSearchOptions
    }

    /** Request with messages, model, temperature and an explicit stream flag. */
    public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature,
            boolean stream) {
        this(messages, model,
                null, null, null, null, null, null, null, null, // store .. maxCompletionTokens
                null, null, null, null, null, null, null, null, // n .. stop
                stream, null, temperature,                      // stream, streamOptions, temperature
                null, null, null, null, null, null, null);      // topP .. webSearchOptions
    }

    /** Non-streaming tool-calling request; temperature is fixed at 0.8. */
    public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, List<FunctionTool> tools,
            Object toolChoice) {
        this(messages, model,
                null, null, null, null, null, null, null, null, // store .. maxCompletionTokens
                null, null, null, null, null, null, null, null, // n .. stop
                false, null, 0.8,                               // stream, streamOptions, temperature
                null, tools, toolChoice,                        // topP, tools, toolChoice
                null, null, null, null);                        // parallelToolCalls .. webSearchOptions
    }

    /** Request with only messages and the stream flag; model is left unset. */
    public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
        this(messages, null,
                null, null, null, null, null, null, null, null, // store .. maxCompletionTokens
                null, null, null, null, null, null, null, null, // n .. stop
                stream, null, null,                             // stream, streamOptions, temperature
                null, null, null, null, null, null, null);      // topP .. webSearchOptions
    }

    /**
     * Returns a copy of this request with {@code streamOptions} replaced; every other component is
     * carried over unchanged.
     *
     * @param streamOptions the new stream options
     * @return the modified copy
     */
    public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
        return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata,
                this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, this.maxTokens,
                this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters,
                this.presencePenalty, this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream,
                streamOptions, this.temperature, this.topP, this.tools, this.toolChoice, this.parallelToolCalls,
                this.user, this.reasoningEffort, this.webSearchOptions);
    }

    /** Helpers producing values for the {@code tool_choice} component. */
    public static class ToolChoiceBuilder {

        public static final String AUTO = "auto";

        public static final String NONE = "none";

        /** @return a map of the form {@code {type: function, function: {name: functionName}}} */
        public static Object FUNCTION(String functionName) {
            return Map.of("type", "function", "function", Map.of("name", functionName));
        }
    }

    /** Voice and encoding for audio output. */
    @JsonInclude(Include.NON_NULL)
    public record AudioParameters(
            @JsonProperty("voice") Voice voice,
            @JsonProperty("format") AudioResponseFormat format) {

        public enum Voice {
            @JsonProperty("alloy") ALLOY,
            @JsonProperty("echo") ECHO,
            @JsonProperty("fable") FABLE,
            @JsonProperty("onyx") ONYX,
            @JsonProperty("nova") NOVA,
            @JsonProperty("shimmer") SHIMMER
        }

        public enum AudioResponseFormat {
            @JsonProperty("mp3") MP3,
            @JsonProperty("flac") FLAC,
            @JsonProperty("opus") OPUS,
            @JsonProperty("pcm16") PCM16,
            @JsonProperty("wav") WAV
        }
    }

    /** Options applied only to streaming requests. */
    @JsonInclude(Include.NON_NULL)
    public record StreamOptions(@JsonProperty("include_usage") Boolean includeUsage) {

        // Non-final as in the original, to keep the public field's contract unchanged.
        public static StreamOptions INCLUDEUSAGE = new StreamOptions(true);
    }

    /** Web search configuration. */
    @JsonInclude(Include.NON_NULL)
    public record WebSearchOptions(
            @JsonProperty("search_context_size") SearchContextSize searchContextSize,
            @JsonProperty("user_location") UserLocation userLocation) {

        public enum SearchContextSize {
            @JsonProperty("low") LOW,
            @JsonProperty("medium") MEDIUM,
            @JsonProperty("high") HIGH
        }

        @JsonInclude(Include.NON_NULL)
        public record UserLocation(
                @JsonProperty("type") String type,
                @JsonProperty("approximate") Approximate approximate) {

            @JsonInclude(Include.NON_NULL)
            public record Approximate(
                    @JsonProperty("city") String city,
                    @JsonProperty("country") String country,
                    @JsonProperty("region") String region,
                    @JsonProperty("timezone") String timezone) {
            }
        }
    }
}
/**
 * A single chat message, used both in requests and in completion choices.
 *
 * <p>{@code rawContent} holds the JSON {@code content} value as-is: a plain string, or some other
 * JSON shape (presumably the list-of-parts form) — {@link #content()} only accepts the string case.
 * Unknown JSON properties are ignored and {@code null} components are omitted.
 */
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record ChatCompletionMessage(
        @JsonProperty("content") Object rawContent,
        @JsonProperty("role") Role role,
        @JsonProperty("name") String name,
        @JsonProperty("tool_call_id") String toolCallId,
        @JsonProperty("tool_calls") @JsonFormat(with = Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
        @JsonProperty("refusal") String refusal,
        @JsonProperty("audio") AudioOutput audioOutput,
        @JsonProperty("annotations") List<Annotation> annotations) {

    /** Creates a message with only content and role set. */
    public ChatCompletionMessage(Object content, Role role) {
        this(content, role, null, null, null, null, null, null);
    }

    /**
     * Returns the content as a string.
     *
     * @return the content, or {@code null} if there is none
     * @throws IllegalStateException if the content is present but not a {@code String}
     */
    public String content() {
        if (this.rawContent == null) {
            return null;
        }
        if (this.rawContent instanceof String text) {
            return text;
        }
        throw new IllegalStateException("The content is not a string!");
    }

    /** Author role of a message. */
    public enum Role {
        @JsonProperty("system") SYSTEM,
        @JsonProperty("user") USER,
        @JsonProperty("assistant") ASSISTANT,
        @JsonProperty("tool") TOOL
    }

    /** One part of multi-modal content: text, an image URL or input audio, selected by {@code type}. */
    @JsonInclude(Include.NON_NULL)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record MediaContent(
            @JsonProperty("type") String type,
            @JsonProperty("text") String text,
            @JsonProperty("image_url") ImageUrl imageUrl,
            @JsonProperty("input_audio") InputAudio inputAudio) {

        public MediaContent(String text) {
            this("text", text, null, null);
        }

        public MediaContent(ImageUrl imageUrl) {
            this("image_url", null, imageUrl, null);
        }

        public MediaContent(InputAudio inputAudio) {
            this("input_audio", null, null, inputAudio);
        }

        @JsonInclude(Include.NON_NULL)
        public record InputAudio(
                @JsonProperty("data") String data,
                @JsonProperty("format") Format format) {

            public enum Format {
                @JsonProperty("mp3") MP3,
                @JsonProperty("wav") WAV
            }
        }

        @JsonInclude(Include.NON_NULL)
        public record ImageUrl(
                @JsonProperty("url") String url,
                @JsonProperty("detail") String detail) {

            public ImageUrl(String url) {
                this(url, null);
            }
        }
    }

    /** A tool call requested by the model; {@code index} is set when it arrives as a streamed fragment. */
    @JsonInclude(Include.NON_NULL)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record ToolCall(
            @JsonProperty("index") Integer index,
            @JsonProperty("id") String id,
            @JsonProperty("type") String type,
            @JsonProperty("function") ChatCompletionFunction function) {

        public ToolCall(String id, String type, ChatCompletionFunction function) {
            this(null, id, type, function);
        }
    }

    /** Name and (string-encoded) arguments of the function the model wants to call. */
    @JsonInclude(Include.NON_NULL)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record ChatCompletionFunction(
            @JsonProperty("name") String name,
            @JsonProperty("arguments") String arguments) {
    }

    /** Audio produced by the model. */
    @JsonInclude(Include.NON_NULL)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record AudioOutput(
            @JsonProperty("id") String id,
            @JsonProperty("data") String data,
            @JsonProperty("expires_at") Long expiresAt,
            @JsonProperty("transcript") String transcript) {
    }

    /** An annotation on the message text, e.g. a URL citation. */
    @JsonInclude(Include.NON_NULL)
    public record Annotation(
            @JsonProperty("type") String type,
            @JsonProperty("url_citation") UrlCitation urlCitation) {

        @JsonInclude(Include.NON_NULL)
        public record UrlCitation(
                @JsonProperty("end_index") Integer endIndex,
                @JsonProperty("start_index") Integer startIndex,
                @JsonProperty("title") String title,
                @JsonProperty("url") String url) {
        }
    }
}
/**
 * Response of a non-streaming chat completion (the body type read by {@code chatCompletionEntity}).
 * Unknown JSON properties are ignored.
 */
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record ChatCompletion(
        @JsonProperty("id") String id,
        @JsonProperty("choices") List<Choice> choices,
        @JsonProperty("created") Long created,
        @JsonProperty("model") String model,
        @JsonProperty("service_tier") String serviceTier,
        @JsonProperty("system_fingerprint") String systemFingerprint,
        @JsonProperty("object") String object,
        @JsonProperty("usage") Usage usage) {

    /** One generated alternative with its message, finish reason and optional log probabilities. */
    @JsonInclude(Include.NON_NULL)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record Choice(
            @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
            @JsonProperty("index") Integer index,
            @JsonProperty("message") ChatCompletionMessage message,
            @JsonProperty("logprobs") LogProbs logprobs) {
    }
}
/**
 * Log-probability information for the content and refusal tokens of a choice.
 */
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record LogProbs(
        @JsonProperty("content") List<Content> content,
        @JsonProperty("refusal") List<Content> refusal) {

    /** A token, its log probability, its raw bytes and the most likely alternatives. */
    @JsonInclude(Include.NON_NULL)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record Content(
            @JsonProperty("token") String token,
            @JsonProperty("logprob") Float logprob,
            @JsonProperty("bytes") List<Integer> probBytes,
            @JsonProperty("top_logprobs") List<TopLogProbs> topLogprobs) {

        /** An alternative token at the same position. */
        @JsonInclude(Include.NON_NULL)
        @JsonIgnoreProperties(ignoreUnknown = true)
        public record TopLogProbs(
                @JsonProperty("token") String token,
                @JsonProperty("logprob") Float logprob,
                @JsonProperty("bytes") List<Integer> probBytes) {
        }
    }
}
/**
 * Token accounting returned with a completion.
 *
 * <p>All wire names are snake_case. The previous mapping ({@code completiontokens}, …) never
 * matched, so with {@code ignoreUnknown = true} every field silently deserialized to {@code null}.
 *
 * @param completionTokens tokens reported under {@code completion_tokens}
 * @param promptTokens tokens reported under {@code prompt_tokens}
 * @param totalTokens tokens reported under {@code total_tokens}
 * @param promptTokensDetails breakdown reported under {@code prompt_tokens_details}, may be null
 * @param completionTokenDetails breakdown reported under {@code completion_tokens_details}, may be null
 */
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record Usage(
        @JsonProperty("completion_tokens") Integer completionTokens,
        @JsonProperty("prompt_tokens") Integer promptTokens,
        @JsonProperty("total_tokens") Integer totalTokens,
        @JsonProperty("prompt_tokens_details") PromptTokensDetails promptTokensDetails,
        @JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {

    /**
     * Creates a usage record without the detailed breakdowns (both set to {@code null}).
     *
     * @param completionTokens completion token count
     * @param promptTokens prompt token count
     * @param totalTokens total token count
     */
    public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) {
        this(completionTokens, promptTokens, totalTokens, null, null);
    }

    /**
     * Breakdown of prompt tokens.
     *
     * @param audioTokens tokens reported under {@code audio_tokens}
     * @param cachedTokens tokens reported under {@code cached_tokens}
     */
    @JsonInclude(Include.NON_NULL)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record PromptTokensDetails(
            @JsonProperty("audio_tokens") Integer audioTokens,
            @JsonProperty("cached_tokens") Integer cachedTokens) {
    }

    /**
     * Breakdown of completion tokens.
     *
     * @param reasoningTokens tokens reported under {@code reasoning_tokens}
     * @param acceptedPredictionTokens tokens reported under {@code accepted_prediction_tokens}
     * @param audioTokens tokens reported under {@code audio_tokens}
     * @param rejectedPredictionTokens tokens reported under {@code rejected_prediction_tokens}
     */
    @JsonInclude(Include.NON_NULL)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record CompletionTokenDetails(
            @JsonProperty("reasoning_tokens") Integer reasoningTokens,
            @JsonProperty("accepted_prediction_tokens") Integer acceptedPredictionTokens,
            @JsonProperty("audio_tokens") Integer audioTokens,
            @JsonProperty("rejected_prediction_tokens") Integer rejectedPredictionTokens) {
    }
}
/**
 * A chunk of a chat completion response; each choice carries a {@code delta} message rather
 * than a complete one.
 *
 * @param id response identifier
 * @param choices the chunk's choices
 * @param created creation timestamp as reported by the server (wire name {@code created})
 * @param model model name reported by the server
 * @param serviceTier value of {@code service_tier}
 * @param systemFingerprint value of {@code system_fingerprint}
 * @param object object type string reported by the server
 * @param usage token accounting, if the server included it in this chunk
 */
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record ChatCompletionChunk(
        @JsonProperty("id") String id,
        @JsonProperty("choices") List<ChunkChoice> choices,
        @JsonProperty("created") Long created,
        @JsonProperty("model") String model,
        // snake_case wire names; without underscores these were silently ignored.
        @JsonProperty("service_tier") String serviceTier,
        @JsonProperty("system_fingerprint") String systemFingerprint,
        @JsonProperty("object") String object,
        @JsonProperty("usage") Usage usage) {

    /**
     * One choice within a chunk.
     *
     * @param finishReason value of {@code finish_reason}
     * @param index position of the choice
     * @param delta the incremental message content for this chunk
     * @param logprobs log-probability information, if present
     */
    @JsonInclude(Include.NON_NULL)
    @JsonIgnoreProperties(ignoreUnknown = true)
    public record ChunkChoice(
            @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
            @JsonProperty("index") Integer index,
            @JsonProperty("delta") ChatCompletionMessage delta,
            @JsonProperty("logprobs") LogProbs logprobs) {
    }
}
/**
 * A single embedding vector returned by the embeddings endpoint.
 *
 * <p>Note: because {@code embedding} is an array component, the record's generated
 * {@code equals}/{@code hashCode} compare the array by reference, not by content.
 *
 * @param index position of this embedding in the response
 * @param embedding the vector values
 * @param object object type string; defaults to {@code "embedding"} in the two-argument constructor
 */
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record Embedding(
        @JsonProperty("index") Integer index,
        @JsonProperty("embedding") float[] embedding,
        @JsonProperty("object") String object) {

    /**
     * Creates an embedding whose {@code object} is {@code "embedding"}.
     *
     * @param index position of this embedding
     * @param embedding the vector values
     */
    public Embedding(Integer index, float[] embedding) {
        this(index, embedding, "embedding");
    }
}
/**
 * Request body for the embeddings endpoint.
 *
 * @param <T> type of the input payload
 * @param input the text (or texts/tokens) to embed
 * @param model the embedding model name
 * @param encodingFormat value sent as {@code encoding_format}; the convenience constructors use {@code "float"}
 * @param dimensions requested output dimensions, omitted when {@code null}
 * @param user end-user identifier, omitted when {@code null}
 */
@JsonInclude(Include.NON_NULL)
public record EmbeddingRequest<T>(
        @JsonProperty("input") T input,
        @JsonProperty("model") String model,
        // snake_case wire name; "encodingformat" would not be recognised by the server.
        @JsonProperty("encoding_format") String encodingFormat,
        @JsonProperty("dimensions") Integer dimensions,
        @JsonProperty("user") String user) {

    /**
     * Creates a request using {@code "float"} encoding and no dimensions/user.
     *
     * @param input the payload to embed
     * @param model the embedding model name
     */
    public EmbeddingRequest(T input, String model) {
        this(input, model, "float", null, null);
    }

    /**
     * Creates a request for {@link OpenAiApi#DEFAULT_EMBEDDING_MODEL}.
     *
     * @param input the payload to embed
     */
    public EmbeddingRequest(T input) {
        this(input, OpenAiApi.DEFAULT_EMBEDDING_MODEL);
    }
}
/**
 * List wrapper returned by the embeddings endpoint.
 *
 * @param <T> element type of {@code data}
 * @param object object type string reported by the server
 * @param data the returned items
 * @param model model name reported by the server
 * @param usage token accounting for the request
 */
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record EmbeddingList<T>(
        @JsonProperty("object") String object,
        @JsonProperty("data") List<T> data,
        @JsonProperty("model") String model,
        @JsonProperty("usage") Usage usage) {
}
/**
 * Fluent builder for {@link OpenAiApi}.
 *
 * <p>Defaults: base URL {@code https://api.openai.com}, completions path {@code /v1/chat/completions},
 * embeddings path {@code /v1/embeddings}, empty headers, fresh {@code RestClient}/{@code WebClient}
 * builders, and {@code RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER}. The API key has no default and must
 * be set before {@link #build()}.
 */
public static class Builder {

    private String baseUrl = "https://api.openai.com";

    private ApiKey apiKey;

    private MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();

    private String completionsPath = "/v1/chat/completions";

    private String embeddingsPath = "/v1/embeddings";

    private RestClient.Builder restClientBuilder = RestClient.builder();

    private WebClient.Builder webClientBuilder = WebClient.builder();

    private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;

    /** Creates a builder holding the defaults described on this class. */
    public Builder() {
    }

    /**
     * Creates a builder pre-populated from an existing client.
     *
     * <p>Headers are copied into a new map. The HTTP client builders come from {@code mutate()} when
     * the source client has them; otherwise fresh builders are used.
     *
     * @param api the client to copy settings from
     */
    public Builder(OpenAiApi api) {
        this.baseUrl = api.getBaseUrl();
        this.apiKey = api.getApiKey();
        this.headers = new LinkedMultiValueMap<>(api.getHeaders());
        this.completionsPath = api.getCompletionsPath();
        this.embeddingsPath = api.getEmbeddingsPath();
        this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder();
        this.webClientBuilder = api.webClient != null ? api.webClient.mutate() : WebClient.builder();
        this.responseErrorHandler = api.getResponseErrorHandler();
    }

    /** @param baseUrl non-empty base URL @return this builder */
    public Builder baseUrl(String baseUrl) {
        Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
        this.baseUrl = baseUrl;
        return this;
    }

    /** @param apiKey non-null API key @return this builder */
    public Builder apiKey(ApiKey apiKey) {
        Assert.notNull(apiKey, "apiKey cannot be null");
        this.apiKey = apiKey;
        return this;
    }

    /** @param simpleApiKey non-null key string, wrapped in a {@code SimpleApiKey} @return this builder */
    public Builder apiKey(String simpleApiKey) {
        Assert.notNull(simpleApiKey, "simpleApiKey cannot be null");
        this.apiKey = new SimpleApiKey(simpleApiKey);
        return this;
    }

    /**
     * Sets the extra headers. The given map is stored as-is (not copied), so later changes to it
     * are visible to this builder.
     *
     * @param headers non-null header map
     * @return this builder
     */
    public Builder headers(MultiValueMap<String, String> headers) {
        Assert.notNull(headers, "headers cannot be null");
        this.headers = headers;
        return this;
    }

    /** @param completionsPath non-empty chat completions path @return this builder */
    public Builder completionsPath(String completionsPath) {
        Assert.hasText(completionsPath, "completionsPath cannot be null or empty");
        this.completionsPath = completionsPath;
        return this;
    }

    /** @param embeddingsPath non-empty embeddings path @return this builder */
    public Builder embeddingsPath(String embeddingsPath) {
        Assert.hasText(embeddingsPath, "embeddingsPath cannot be null or empty");
        this.embeddingsPath = embeddingsPath;
        return this;
    }

    /** @param restClientBuilder non-null builder for synchronous calls @return this builder */
    public Builder restClientBuilder(RestClient.Builder restClientBuilder) {
        Assert.notNull(restClientBuilder, "restClientBuilder cannot be null");
        this.restClientBuilder = restClientBuilder;
        return this;
    }

    /** @param webClientBuilder non-null builder for streaming calls @return this builder */
    public Builder webClientBuilder(WebClient.Builder webClientBuilder) {
        Assert.notNull(webClientBuilder, "webClientBuilder cannot be null");
        this.webClientBuilder = webClientBuilder;
        return this;
    }

    /** @param responseErrorHandler non-null error handler @return this builder */
    public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {
        Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null");
        this.responseErrorHandler = responseErrorHandler;
        return this;
    }

    /**
     * Builds the client from the collected settings.
     *
     * @return a new {@link OpenAiApi}
     * @throws IllegalArgumentException if no API key was set
     */
    public OpenAiApi build() {
        Assert.notNull(this.apiKey, "apiKey must be set");
        return new OpenAiApi(this.baseUrl, this.apiKey, this.headers, this.completionsPath,
                this.embeddingsPath, this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler);
    }
}
} // closes the enclosing OpenAiApi class (opened before this view)