Step 7: Gather Credentials
Collect the following credentials:
- API Key (from "Keys and Endpoint" section)
- Endpoint URL (for example https://javainuse-service.openai.azure.com/)
- Deployment ID (your chosen deployment name)
Integrate Azure OpenAI Text Embedding Model to generate vector embeddings
Next, we will modify the Spring Boot project to integrate the deployed Azure OpenAI text embedding model.
In application.properties, add the embedding model deployment ID:
azure.openai.api.key=<your-azure-openai-api-key>
azure.openai.endpoint=https://javainuse-service.openai.azure.com/
azure.openai.deployment.model.id=gpt-4o
azure.openai.embedding.model.id=text-embedding-3-small
Create a new model class named EmbedData, which we will use to pass the text for which we want vector embeddings from the Azure text embedding model.
package com.javainuse.spring_boot_ai;
public class EmbedData {
private String data;
public String getData() {
return data;
}
public void setData(String data) {
this.data = data;
}
}
Next, we will expose a POST API with the URL /get-embed. The request body is an EmbedData instance. Using the data passed in EmbedData, we call the deployed Azure text embedding model to get the vector embeddings for that data.
package com.javainuse.spring_boot_ai;
import java.util.ArrayList;
import java.util.List;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatMessage;
import com.azure.ai.openai.models.ChatRole;
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.azure.core.credential.AzureKeyCredential;
@RestController
public class PromptController {
@Value("${azure.openai.api.key}")
private String azureOpenaiKey;
@Value("${azure.openai.endpoint}")
private String endpoint;
@Value("${azure.openai.deployment.model.id}")
private String deploymentOrModelId;
@Value("${azure.openai.embedding.model.id}")
private String embeddingModelId;
@PostMapping("/answer")
public List<String> getMethodName(@RequestBody PromptQuestion promptQuestion) {
List<String> responseList = new ArrayList<>();
try {
OpenAIClient client = new OpenAIClientBuilder().endpoint(endpoint)
.credential(new AzureKeyCredential(azureOpenaiKey)).buildClient();
List<ChatMessage> messages = new ArrayList<>();
messages.add(new ChatMessage(ChatRole.SYSTEM)
.setContent("You are an AI assistant that helps people find information."));
messages.add(new ChatMessage(ChatRole.USER).setContent(getPrompt(promptQuestion)));
ChatCompletionsOptions options = new ChatCompletionsOptions(messages).setTemperature(0.7).setTopP(0.95)
.setMaxTokens(800);
ChatCompletions completions = client.getChatCompletions(deploymentOrModelId, options);
for (ChatChoice choice : completions.getChoices()) {
responseList.add(choice.getMessage().getContent().trim());
}
} catch (Exception ex) {
ex.printStackTrace();
responseList.add("Exception Occurred");
}
return responseList;
}
@PostMapping("/get-embed")
public List<Double> getEmbedding(@RequestBody EmbedData embedData) {
try {
OpenAIClient client = new OpenAIClientBuilder().endpoint(endpoint)
.credential(new AzureKeyCredential(azureOpenaiKey)).buildClient();
String input = getEmbedData(embedData);
Embeddings embeddings = client.getEmbeddings(embeddingModelId,
new com.azure.ai.openai.models.EmbeddingsOptions(List.of(input)));
EmbeddingItem embeddingItem = embeddings.getData().get(0);
List<Double> embedding = embeddingItem.getEmbedding();
return embedding;
} catch (Exception ex) {
ex.printStackTrace();
return List.of();
}
}
private String getPrompt(PromptQuestion promptQuestion) {
String input = promptQuestion.getQuestion().trim();
return input;
}
private String getEmbedData(EmbedData embedData) {
String input = embedData.getData().trim();
return input;
}
}
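The SDK call client.getEmbeddings(embeddingModelId, ...) wraps the Azure OpenAI REST embeddings operation. To make that concrete, here is a sketch that builds (but does not send) the equivalent raw request with the JDK's java.net.http API. It assumes the REST path /openai/deployments/{deployment}/embeddings, the api-key header and an api-version of 2024-02-01; check the Azure OpenAI REST reference for the version your resource supports. The class and helper names are hypothetical.

```java
import java.net.URI;
import java.net.http.HttpRequest;

// Sketch of the raw REST request behind client.getEmbeddings(...); not part of the original project.
final class AzureEmbeddingRequestSketch {

    // Assumption: a GA api-version; verify against the Azure OpenAI REST reference.
    static final String API_VERSION = "2024-02-01";

    static HttpRequest buildEmbeddingRequest(String endpoint, String deployment, String apiKey, String text) {
        // Ensure exactly one slash between the endpoint and the path
        String base = endpoint.endsWith("/") ? endpoint : endpoint + "/";
        URI uri = URI.create(base + "openai/deployments/" + deployment
                + "/embeddings?api-version=" + API_VERSION);
        String body = "{\"input\": \"" + escapeJson(text) + "\"}";
        return HttpRequest.newBuilder(uri)
                .header("api-key", apiKey) // same key as azure.openai.api.key
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
    }

    // Hypothetical helper: escapes only backslashes and double quotes, not a full JSON encoder
    static String escapeJson(String s) {
        return s.replace("\\", "\\\\").replace("\"", "\\\"");
    }
}
```

In the Spring Boot code the SDK builds and sends this request for us; the sketch only shows which pieces of application.properties end up where.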
If we now start the application, we can test the embeddings call by sending a POST request to http://localhost:8080/get-embed with the following JSON body:
{
"data":"Tesla sales crash in Europe, dropping 45% in Jan 2025 as EV competition heats up"
}
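To call the endpoint from plain Java, here is a minimal sketch using the JDK's java.net.http.HttpClient (not part of the original project). The class name and the naive parseVector helper are hypothetical; parseVector only handles a flat JSON array of numbers, which is the shape Spring produces for the List&lt;Double&gt; returned by /get-embed.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.List;

// Hypothetical client for the /get-embed endpoint exposed above.
final class GetEmbedClientSketch {

    static HttpRequest buildRequest(String baseUrl, String text) {
        // Body matches the EmbedData class: {"data": "..."}
        String escaped = text.replace("\\", "\\\\").replace("\"", "\\\"");
        String body = "{\"data\":\"" + escaped + "\"}";
        return HttpRequest.newBuilder(URI.create(baseUrl + "/get-embed"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
    }

    // Parses a flat JSON array of numbers such as [0.1,-0.2]; not a general JSON parser
    static List<Double> parseVector(String json) {
        String trimmed = json.trim();
        String inner = trimmed.substring(1, trimmed.length() - 1).trim();
        List<Double> values = new ArrayList<>();
        if (inner.isEmpty()) {
            return values; // the controller returns [] when the embedding call fails
        }
        for (String part : inner.split(",")) {
            values.add(Double.parseDouble(part.trim()));
        }
        return values;
    }

    public static void main(String[] args) throws Exception {
        HttpRequest request = buildRequest("http://localhost:8080",
                "Tesla sales crash in Europe, dropping 45% in Jan 2025 as EV competition heats up");
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        List<Double> vector = parseVector(response.body());
        System.out.println("dimensions: " + vector.size());
    }
}
```

With text-embedding-3-small at its default settings, the printed dimension count should match the 1536 dims used for the Elasticsearch mapping later on.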
Store vector embeddings in elasticsearch
In a previous tutorial we implemented a Spring Boot + Elasticsearch CRUD example.
Next, we will integrate Elasticsearch into our application and use it as a vector database to store the vector embeddings.
Go to the Elasticsearch downloads page and click the Windows button to download the latest Elasticsearch installable. The paths used later in this tutorial refer to version 8.13.4; adjust them to match the version you download.

This downloads a zip file. Unzip it.

Open a command prompt as administrator, go to the Elasticsearch bin folder and run the following command:
elasticsearch.bat
This starts Elasticsearch.
Starting with Elasticsearch 8, security is enabled by default, so HTTPS and password authentication are enabled out of the box.
If we now open https://localhost:9200 in a browser, we are prompted for a username and password.

Next, we will set the Elasticsearch password. Do not close the command prompt window that is running Elasticsearch. Open another command prompt as administrator and go to the bin folder.
Use the following command to set the password of the elastic user:
elasticsearch-reset-password -u elastic --interactive
I have set the password of the elastic user to password.
So if I now go to the Elasticsearch URL https://localhost:9200, I can log in with the username elastic and the password password.
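The browser prompt is HTTP Basic authentication. As an illustration (not part of the original project), the sketch below builds the Authorization header for the elastic user with the JDK's Base64 encoder and attaches it to a request for https://localhost:9200/. The request is only built, not sent: sending it from Java would also require the JVM to trust the Elasticsearch CA certificate, which the truststore set up later in this tutorial addresses. The class name is hypothetical.

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets;
import java.util.Base64;

// Hypothetical illustration of the Basic authentication Elasticsearch 8 expects by default.
final class ElasticBasicAuthSketch {

    // Value of the Authorization header: "Basic " + base64(username:password)
    static String basicAuthHeader(String username, String password) {
        String credentials = username + ":" + password;
        return "Basic " + Base64.getEncoder()
                .encodeToString(credentials.getBytes(StandardCharsets.UTF_8));
    }

    // Builds (but does not send) a GET request to the Elasticsearch root endpoint
    static HttpRequest rootRequest(String username, String password) {
        return HttpRequest.newBuilder(URI.create("https://localhost:9200/"))
                .header("Authorization", basicAuthHeader(username, password))
                .GET()
                .build();
    }
}
```

Base64 is an encoding, not encryption, which is one reason Elasticsearch 8 pairs password authentication with HTTPS.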

We are now able to access Elasticsearch.
Add the Elasticsearch dependencies to pom.xml:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.4.2</version>
<relativePath /> <!-- lookup parent from repository -->
</parent>
<groupId>com.javainuse</groupId>
<artifactId>spring-ai-openai</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>spring-boot-ai</name>
<description>Demo project for Spring Boot</description>
<properties>
<java.version>17</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-ai-openai</artifactId>
<version>1.0.0-beta.1</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-elasticsearch</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>co.elastic.clients</groupId>
<artifactId>elasticsearch-java</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
By default, Elasticsearch runs on HTTPS, so we will create a truststore file. Our Spring Boot application uses this truststore to establish a secure SSL/TLS connection with the Elasticsearch cluster.
We import the Elasticsearch CA certificate into this truststore. When the application connects to Elasticsearch, it uses the CA certificate in the truststore to verify the identity of the Elasticsearch server and establish a trusted connection. To create the truststore, go to the jdk\bin folder of the Elasticsearch installation and run the following command:
.\keytool.exe -import -file "E:\trial\elasticsearch-8.13.4-windows-x86_64\elasticsearch-8.13.4\config\certs\http_ca.crt" -keystore "E:\trial\elasticsearch-8.13.4-windows-x86_64\elasticsearch-8.13.4\config\certs\truststore.p12" -storepass javainuse -noprompt -storetype pkcs12
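Before wiring the truststore into Spring Boot, it can help to confirm that it loads with the store password and contains the imported CA certificate. A minimal sketch using java.security.KeyStore follows; the class name is hypothetical, and the path and password in main are the ones from the keytool command above.

```java
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.KeyStore;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical check that the PKCS12 truststore created by keytool is usable.
final class TrustStoreCheckSketch {

    // Loads the truststore using the -storepass value given to keytool
    static KeyStore load(Path path, char[] password) throws Exception {
        KeyStore trustStore = KeyStore.getInstance("PKCS12");
        try (InputStream in = Files.newInputStream(path)) {
            trustStore.load(in, password);
        }
        return trustStore;
    }

    // Returns the aliases of all trusted-certificate entries (the imported CA should be one)
    static List<String> trustedCertificateAliases(KeyStore store) throws Exception {
        List<String> result = new ArrayList<>();
        for (String alias : Collections.list(store.aliases())) {
            if (store.isCertificateEntry(alias)) {
                result.add(alias);
            }
        }
        return result;
    }

    public static void main(String[] args) throws Exception {
        Path path = Path.of("E:\\trial\\elasticsearch-8.13.4-windows-x86_64\\elasticsearch-8.13.4\\config\\certs\\truststore.p12");
        KeyStore store = load(path, "javainuse".toCharArray());
        System.out.println("Trusted certificates: " + trustedCertificateAliases(store));
    }
}
```

An empty alias list, or an exception about the password, means the keytool step needs to be repeated before the Spring Boot client can connect.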
The HttpClientConfigImpl class is a Spring configuration class that implements the HttpClientConfigCallback interface. This interface lets us customize the HttpAsyncClientBuilder, which creates the HTTP client used to communicate with Elasticsearch.
The HttpClientConfigImpl class performs the following tasks:
- Authentication setup: The class creates a CredentialsProvider object and sets the username and password required to authenticate with the Elasticsearch cluster. In this example, the username is "elastic" and the password is "password" (the value set earlier with elasticsearch-reset-password).
- SSL/TLS configuration: To establish a secure connection with the Elasticsearch cluster, the class loads the truststore created above from E:\trial\elasticsearch-8.13.4-windows-x86_64\elasticsearch-8.13.4\config\certs\truststore.p12. The truststore contains the trusted Certificate Authority (CA) certificates used to verify the server's identity during the SSL/TLS handshake. The class builds an SSLContext with SSLContextBuilder, loading the truststore with its store password ("javainuse", the -storepass value used with keytool).
- Client configuration: Finally, the class sets the CredentialsProvider and SSLContext on the HttpAsyncClientBuilder. The Elasticsearch client built from it can then authenticate with the provided credentials and establish a secure connection using the specified truststore.
package com.javainuse.spring_boot_ai.config;
import java.io.File;
import javax.net.ssl.SSLContext;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.ssl.SSLContexts;
import org.elasticsearch.client.RestClientBuilder.HttpClientConfigCallback;
import org.springframework.context.annotation.Configuration;
@Configuration
public class HttpClientConfigImpl implements HttpClientConfigCallback {
@Override
public HttpAsyncClientBuilder customizeHttpClient(HttpAsyncClientBuilder httpClientBuilder) {
try {
final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
UsernamePasswordCredentials usernamePasswordCredentials = new UsernamePasswordCredentials("elastic",
"password");
credentialsProvider.setCredentials(AuthScope.ANY, usernamePasswordCredentials);
httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
String trustLocationStore = "E:\\trial\\elasticsearch-8.13.4-windows-x86_64\\elasticsearch-8.13.4\\config\\certs\\truststore.p12";
File trustLocationFile = new File(trustLocationStore);
SSLContextBuilder sslContextBuilder = SSLContexts.custom().loadTrustMaterial(trustLocationFile,
"javainuse".toCharArray());
SSLContext sslContext = sslContextBuilder.build();
httpClientBuilder.setSSLContext(sslContext);
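} catch (Exception e) {
// Fail fast: continuing without credentials or the truststore would only surface later as confusing connection errors
throw new IllegalStateException("Could not configure the Elasticsearch HTTP client", e);
}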
return httpClientBuilder;
}
}
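HttpClientConfigImpl uses Apache's SSLContexts.custom().loadTrustMaterial(...) to build its SSLContext. For comparison, here is a JDK-only sketch of roughly the same step: load the PKCS12 truststore, initialise a TrustManagerFactory with it and create a TLS SSLContext. The class and method names are hypothetical, and the httpClient helper (handing the context to java.net.http.HttpClient) is only an example consumer; the Elasticsearch Java client in this tutorial uses the Apache-based configuration above.

```java
import java.io.InputStream;
import java.net.http.HttpClient;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.KeyStore;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;

// Hypothetical JDK-only equivalent of the truststore handling in HttpClientConfigImpl.
final class JdkSslContextSketch {

    // Loads the PKCS12 truststore created with keytool and builds an SSLContext from it
    static SSLContext fromTrustStoreFile(Path path, char[] storePassword) throws Exception {
        KeyStore trustStore = KeyStore.getInstance("PKCS12");
        try (InputStream in = Files.newInputStream(path)) {
            trustStore.load(in, storePassword);
        }
        return fromTrustStore(trustStore);
    }

    // Trusts the certificates in the given store (passing null falls back to the JDK's default cacerts)
    static SSLContext fromTrustStore(KeyStore trustStore) throws Exception {
        TrustManagerFactory tmf =
                TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        tmf.init(trustStore);
        SSLContext sslContext = SSLContext.getInstance("TLS");
        // No key managers (no client certificate) and the default SecureRandom
        sslContext.init(null, tmf.getTrustManagers(), null);
        return sslContext;
    }

    // Example consumer: a JDK HttpClient that validates servers against this SSLContext
    static HttpClient httpClient(SSLContext sslContext) {
        return HttpClient.newBuilder().sslContext(sslContext).build();
    }

    public static void main(String[] args) throws Exception {
        SSLContext ctx = fromTrustStoreFile(
                Path.of("E:\\trial\\elasticsearch-8.13.4-windows-x86_64\\elasticsearch-8.13.4\\config\\certs\\truststore.p12"),
                "javainuse".toCharArray());
        System.out.println("SSLContext protocol: " + ctx.getProtocol());
    }
}
```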
Previously we discussed how to configure the Elasticsearch client to establish a secure and authenticated connection with the Elasticsearch cluster.
However, to fully integrate Elasticsearch into our Spring Boot application, we need to create a bean that encapsulates the client configuration and provides an instance of the ElasticsearchClient class.
The ESClient class, annotated with @Component, serves this purpose. Let's break down the getElasticsearchClient method within this class:
- Initializing the RestClientBuilder: The RestClientBuilder is initialized with an HttpHost object, which specifies the hostname, port, and protocol for connecting to the Elasticsearch cluster. In this example, the client connects to a local Elasticsearch instance running on https://localhost:9200.
- Setting the HttpClientConfigCallback: An HttpClientConfigImpl instance is created and set on the RestClientBuilder. This callback configures the HttpAsyncClientBuilder with the credentials and SSL/TLS settings discussed in the previous section.
- Building the RestClient: The RestClient instance is created from the configured RestClientBuilder.
- Creating the RestClientTransport: The RestClientTransport wraps the RestClient and provides the transport layer for communicating with the Elasticsearch cluster. It is initialized with the RestClient instance and a JacksonJsonpMapper, which handles JSON serialization and deserialization.
- Instantiating the ElasticsearchClient: Finally, an ElasticsearchClient is created from the RestClientTransport. This client provides a high-level API for interacting with the Elasticsearch cluster, such as indexing documents, executing searches, and managing indices.
package com.javainuse.spring_boot_ai.config;
import org.apache.http.HttpHost;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.stereotype.Component;
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
import co.elastic.clients.transport.rest_client.RestClientTransport;
@Component
public class ESClient {
@Bean
public ElasticsearchClient getElasticsearchClient() {
RestClientBuilder builder = RestClient.builder(new HttpHost("localhost", 9200, "https"));
RestClientBuilder.HttpClientConfigCallback httpClientConfigCallback = new HttpClientConfigImpl();
builder.setHttpClientConfigCallback(httpClientConfigCallback);
RestClient restClient = builder.build();
RestClientTransport restClientTransport = new RestClientTransport(restClient, new JacksonJsonpMapper());
return new ElasticsearchClient(restClientTransport);
}
}
Create EmbeddingDocument as follows. We map the embedding field as FieldType.Dense_Vector. A dense_vector field stores a fixed-length array of numeric values (here 1536, as set by dims) and is used for k-nearest-neighbour (kNN) similarity search, which is what semantic search relies on.
package com.javainuse.spring_boot_ai.elasticsearch;
import org.springframework.data.annotation.Id;
import org.springframework.data.elasticsearch.annotations.Document;
import org.springframework.data.elasticsearch.annotations.Field;
import org.springframework.data.elasticsearch.annotations.FieldType;
import java.util.List;
@Document(indexName = "embeddings")
public class EmbeddingDocument {
@Id
private String id;
@Field(type = FieldType.Text)
private String text;
@Field(type = FieldType.Dense_Vector, dims = 1536) // text-embedding-3-small returns 1536-dimensional vectors by default
private List<Double> embedding;
public EmbeddingDocument() {
}
public EmbeddingDocument(String id, String text, List<Double> embedding) {
this.id = id;
this.text = text;
this.embedding = embedding;
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getText() {
return text;
}
public void setText(String text) {
this.text = text;
}
public List<Double> getEmbedding() {
return embedding;
}
public void setEmbedding(List<Double> embedding) {
this.embedding = embedding;
}
}
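Because the mapping fixes dims = 1536, Elasticsearch rejects a document whose vector has a different length. A hypothetical guard (not part of the original project) can fail fast before save() with a clearer message, for example if the embedding deployment is switched to a model with a different output size, or if the controller passes along the empty list it returns on errors.

```java
import java.util.List;

// Hypothetical guard for EmbeddingDocument.embedding; not part of the original project.
final class EmbeddingDimensionCheck {

    // Must match dims in @Field(type = FieldType.Dense_Vector, dims = 1536)
    static final int EXPECTED_DIMS = 1536;

    static void requireExpectedDims(List<Double> embedding) {
        int actual = embedding == null ? 0 : embedding.size();
        if (actual != EXPECTED_DIMS) {
            throw new IllegalArgumentException(
                    "Expected " + EXPECTED_DIMS + " dimensions but got " + actual);
        }
    }
}
```

A natural place to call it would be at the start of EmbeddingService.storeEmbedding, before mapping to the document.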
Create EmbeddingDto as follows:
package com.javainuse.spring_boot_ai.elasticsearch;
import java.util.List;
public class EmbeddingDto {
private String id;
private String text;
private List<Double> embedding;
public EmbeddingDto() {
}
public EmbeddingDto(String id, String text, List<Double> embedding) {
this.id = id;
this.text = text;
this.embedding = embedding;
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getText() {
return text;
}
public void setText(String text) {
this.text = text;
}
public List<Double> getEmbedding() {
return embedding;
}
public void setEmbedding(List<Double> embedding) {
this.embedding = embedding;
}
}
Next we define the Spring Data Elasticsearch repository interface for the EmbeddingDocument entity.
By extending ElasticsearchRepository, it automatically provides standard CRUD (Create, Read, Update, Delete) operations for
Elasticsearch without requiring manual implementation. The <EmbeddingDocument, String> type parameters specify the document type and the type of its primary key (String). The @Repository annotation marks it as a Spring repository; it is optional here, because Spring Data detects interfaces that extend ElasticsearchRepository during repository scanning.
package com.javainuse.spring_boot_ai.elasticsearch;
import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface EmbeddingRepository extends ElasticsearchRepository<EmbeddingDocument, String> {
}
Next we create a service to make use of the EmbeddingRepository to store the embeddings in elasticsearch.
package com.javainuse.spring_boot_ai.elasticsearch;
import java.util.UUID;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
public class EmbeddingService {
@Autowired
private EmbeddingRepository embeddingRepository;
/**
* Store a new embedding
*/
public EmbeddingDto storeEmbedding(EmbeddingDto embeddingDto) {
if (embeddingDto.getId() == null || embeddingDto.getId().isEmpty()) {
embeddingDto.setId(UUID.randomUUID().toString());
}
EmbeddingDocument document = mapToDocument(embeddingDto);
EmbeddingDocument savedDocument = embeddingRepository.save(document);
return mapToDto(savedDocument);
}
private EmbeddingDocument mapToDocument(EmbeddingDto dto) {
return new EmbeddingDocument(dto.getId(), dto.getText(), dto.getEmbedding());
}
private EmbeddingDto mapToDto(EmbeddingDocument document) {
return new EmbeddingDto(document.getId(), document.getText(), document.getEmbedding());
}
}
Finally, we expose a REST endpoint, /store-embed, that generates the embedding for the submitted text and saves it in Elasticsearch using the EmbeddingService.
package com.javainuse.spring_boot_ai;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatMessage;
import com.azure.ai.openai.models.ChatRole;
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.javainuse.spring_boot_ai.elasticsearch.EmbeddingDto;
import com.javainuse.spring_boot_ai.elasticsearch.EmbeddingService;
import com.azure.core.credential.AzureKeyCredential;
@RestController
public class PromptController {
@Value("${azure.openai.api.key}")
private String azureOpenaiKey;
@Value("${azure.openai.endpoint}")
private String endpoint;
@Value("${azure.openai.deployment.model.id}")
private String deploymentOrModelId;
@Value("${azure.openai.embedding.model.id}")
private String embeddingModelId;
@Autowired
private EmbeddingService embeddingService;
@PostMapping("/answer")
public List<String> getMethodName(@RequestBody PromptQuestion promptQuestion) {
List<String> responseList = new ArrayList<>();
try {
OpenAIClient client = new OpenAIClientBuilder().endpoint(endpoint)
.credential(new AzureKeyCredential(azureOpenaiKey)).buildClient();
List<ChatMessage> messages = new ArrayList<>();
messages.add(new ChatMessage(ChatRole.SYSTEM)
.setContent("You are an AI assistant that helps people find information."));
messages.add(new ChatMessage(ChatRole.USER).setContent(getPrompt(promptQuestion)));
ChatCompletionsOptions options = new ChatCompletionsOptions(messages).setTemperature(0.7).setTopP(0.95)
.setMaxTokens(800);
ChatCompletions completions = client.getChatCompletions(deploymentOrModelId, options);
for (ChatChoice choice : completions.getChoices()) {
responseList.add(choice.getMessage().getContent().trim());
}
} catch (Exception ex) {
ex.printStackTrace();
responseList.add("Exception Occurred");
}
return responseList;
}
@PostMapping("/get-embed")
public List<Double> getEmbedding(@RequestBody EmbedData embedData) {
try {
OpenAIClient client = new OpenAIClientBuilder().endpoint(endpoint)
.credential(new AzureKeyCredential(azureOpenaiKey)).buildClient();
String input = getEmbedData(embedData);
Embeddings embeddings = client.getEmbeddings(embeddingModelId,
new com.azure.ai.openai.models.EmbeddingsOptions(List.of(input)));
EmbeddingItem embeddingItem = embeddings.getData().get(0);
List<Double> embedding = embeddingItem.getEmbedding();
return embedding;
} catch (Exception ex) {
ex.printStackTrace();
return List.of();
}
}
@PostMapping("/store-embed")
public ResponseEntity<EmbeddingDto> storeEmbedding(@RequestBody EmbedData embedData) {
try {
String text = getEmbedData(embedData);
List<Double> embedding = getEmbedding(embedData);
if (embedding.isEmpty()) {
return new ResponseEntity<>(HttpStatus.INTERNAL_SERVER_ERROR);
}
EmbeddingDto embeddingDto = new EmbeddingDto();
embeddingDto.setId(UUID.randomUUID().toString());
embeddingDto.setText(text);
embeddingDto.setEmbedding(embedding);
EmbeddingDto savedEmbedding = embeddingService.storeEmbedding(embeddingDto);
return new ResponseEntity<>(savedEmbedding, HttpStatus.CREATED);
} catch (Exception ex) {
ex.printStackTrace();
return new ResponseEntity<>(HttpStatus.INTERNAL_SERVER_ERROR);
}
}
private String getPrompt(PromptQuestion promptQuestion) {
String input = promptQuestion.getQuestion().trim();
return input;
}
private String getEmbedData(EmbedData embedData) {
String input = embedData.getData().trim();
return input;
}
}
Retrieve similar embeddings and send them, along with the question, to Azure OpenAI GPT-4o
In EmbeddingRepository, add the findSimilarDocuments method. This method takes a query vector as input, compares it against the embedding field of the stored documents, and returns up to 5 of the most similar documents. Similarity is computed with the metric configured for the dense_vector field, such as cosine similarity.
Query parameters:
- field: "embedding", the vector field to search against
- query_vector: the input vector to compare (the first method parameter, ?0)
- k: 5, return the top 5 most similar documents
- num_candidates: 100, the number of nearest-neighbour candidates considered on each shard before the top k are selected
package com.javainuse.spring_boot_ai.elasticsearch;
import java.util.List;
import org.springframework.data.elasticsearch.annotations.Query;
import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;
import org.springframework.stereotype.Repository;
@Repository
public interface EmbeddingRepository extends ElasticsearchRepository<EmbeddingDocument, String> {
@Query("{\"knn\":{\"field\":\"embedding\",\"query_vector\":?0,\"k\":5,\"num_candidates\":100}}")
List<EmbeddingDocument> findSimilarDocuments(List<Double> queryVector);
}
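To make the k parameter and the similarity metric concrete, the sketch below performs an exact (brute-force) k-nearest-neighbour search with cosine similarity, cos(a, b) = a·b / (|a| |b|), over an in-memory map of document IDs to vectors. It is an illustration only and not how Elasticsearch executes the query: Elasticsearch's knn search is approximate (it uses an HNSW graph index), and num_candidates bounds how many candidates it examines per shard before returning the top k, trading accuracy for speed. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

// Hypothetical exact kNN for illustration; Elasticsearch uses approximate HNSW search instead.
final class BruteForceKnnSketch {

    // Cosine similarity: dot(a, b) / (|a| * |b|)
    static double cosine(List<Double> a, List<Double> b) {
        if (a.size() != b.size()) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.size(); i++) {
            dot += a.get(i) * b.get(i);
            normA += a.get(i) * a.get(i);
            normB += b.get(i) * b.get(i);
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Returns the ids of the k documents most similar to the query, best match first
    static List<String> topK(List<Double> query, Map<String, List<Double>> docs, int k) {
        List<Map.Entry<String, List<Double>>> entries = new ArrayList<>(docs.entrySet());
        entries.sort(Comparator.comparingDouble(
                (Map.Entry<String, List<Double>> e) -> cosine(query, e.getValue())).reversed());
        List<String> ids = new ArrayList<>();
        for (int i = 0; i < Math.min(k, entries.size()); i++) {
            ids.add(entries.get(i).getKey());
        }
        return ids;
    }
}
```

Scoring every stored vector like this is linear in the number of documents, which is why a vector database keeps an approximate index instead.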
In EmbeddingService, add a findSimilarDocuments method that calls the findSimilarDocuments method of the EmbeddingRepository and maps the results to EmbeddingDto objects.
package com.javainuse.spring_boot_ai.elasticsearch;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
public class EmbeddingService {
@Autowired
private EmbeddingRepository embeddingRepository;
public EmbeddingDto storeEmbedding(EmbeddingDto embeddingDto) {
if (embeddingDto.getId() == null || embeddingDto.getId().isEmpty()) {
embeddingDto.setId(UUID.randomUUID().toString());
}
EmbeddingDocument document = mapToDocument(embeddingDto);
EmbeddingDocument savedDocument = embeddingRepository.save(document);
return mapToDto(savedDocument);
}
public List<EmbeddingDto> findSimilarDocuments(List<Double> queryVector) {
List<EmbeddingDocument> similarDocuments = embeddingRepository.findSimilarDocuments(queryVector);
return similarDocuments.stream().map(this::mapToDto).collect(Collectors.toList());
}
private EmbeddingDocument mapToDocument(EmbeddingDto dto) {
return new EmbeddingDocument(dto.getId(), dto.getText(), dto.getEmbedding());
}
private EmbeddingDto mapToDto(EmbeddingDocument document) {
return new EmbeddingDto(document.getId(), document.getText(), document.getEmbedding());
}
}
Finally, in the controller class, expose a POST REST endpoint at the /enhanced-answer URL as follows:
package com.javainuse.spring_boot_ai;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatMessage;
import com.azure.ai.openai.models.ChatRole;
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.javainuse.spring_boot_ai.elasticsearch.EmbeddingDto;
import com.javainuse.spring_boot_ai.elasticsearch.EmbeddingService;
import com.azure.core.credential.AzureKeyCredential;
@RestController
public class PromptController {
@Value("${azure.openai.api.key}")
private String azureOpenaiKey;
@Value("${azure.openai.endpoint}")
private String endpoint;
@Value("${azure.openai.deployment.model.id}")
private String deploymentOrModelId;
@Value("${azure.openai.embedding.model.id}")
private String embeddingModelId;
@Autowired
private EmbeddingService embeddingService;
@PostMapping("/answer")
public List<String> getMethodName(@RequestBody PromptQuestion promptQuestion) {
List<String> responseList = new ArrayList<>();
try {
OpenAIClient client = new OpenAIClientBuilder().endpoint(endpoint)
.credential(new AzureKeyCredential(azureOpenaiKey)).buildClient();
List<ChatMessage> messages = new ArrayList<>();
messages.add(new ChatMessage(ChatRole.SYSTEM)
.setContent("You are an AI assistant that helps people find information."));
messages.add(new ChatMessage(ChatRole.USER).setContent(getPrompt(promptQuestion)));
ChatCompletionsOptions options = new ChatCompletionsOptions(messages).setTemperature(0.7).setTopP(0.95)
.setMaxTokens(800);
ChatCompletions completions = client.getChatCompletions(deploymentOrModelId, options);
for (ChatChoice choice : completions.getChoices()) {
responseList.add(choice.getMessage().getContent().trim());
}
} catch (Exception ex) {
ex.printStackTrace();
responseList.add("Exception Occurred");
}
return responseList;
}
@PostMapping("/get-embed")
public List<Double> getEmbedding(@RequestBody EmbedData embedData) {
try {
OpenAIClient client = new OpenAIClientBuilder().endpoint(endpoint)
.credential(new AzureKeyCredential(azureOpenaiKey)).buildClient();
String input = getEmbedData(embedData);
Embeddings embeddings = client.getEmbeddings(embeddingModelId,
new com.azure.ai.openai.models.EmbeddingsOptions(List.of(input)));
EmbeddingItem embeddingItem = embeddings.getData().get(0);
List<Double> embedding = embeddingItem.getEmbedding();
return embedding;
} catch (Exception ex) {
ex.printStackTrace();
return List.of();
}
}
@PostMapping("/store-embed")
public ResponseEntity<EmbeddingDto> storeEmbedding(@RequestBody EmbedData embedData) {
try {
String text = getEmbedData(embedData);
List<Double> embedding = getEmbedding(embedData);
if (embedding.isEmpty()) {
return new ResponseEntity<>(HttpStatus.INTERNAL_SERVER_ERROR);
}
EmbeddingDto embeddingDto = new EmbeddingDto();
embeddingDto.setId(UUID.randomUUID().toString());
embeddingDto.setText(text);
embeddingDto.setEmbedding(embedding);
EmbeddingDto savedEmbedding = embeddingService.storeEmbedding(embeddingDto);
return new ResponseEntity<>(savedEmbedding, HttpStatus.CREATED);
} catch (Exception ex) {
ex.printStackTrace();
return new ResponseEntity<>(HttpStatus.INTERNAL_SERVER_ERROR);
}
}
@PostMapping("/enhanced-answer")
public ResponseEntity<List<String>> getEnhancedAnswer(@RequestBody PromptQuestion promptQuestion) {
List<String> responseList = new ArrayList<>();
try {
// First, search for similar documents
EmbedData data = new EmbedData();
data.setData(getPrompt(promptQuestion));
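List<Double> queryEmbedding = getEmbedding(data);
if (queryEmbedding.isEmpty()) {
// getEmbedding returns an empty list on failure; a kNN search with an empty vector would fail
responseList.add("Exception Occurred");
return new ResponseEntity<>(responseList, HttpStatus.INTERNAL_SERVER_ERROR);
}
List<EmbeddingDto> similarDocuments = embeddingService.findSimilarDocuments(queryEmbedding);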
// Extract the most relevant contexts (top 3)
StringBuilder context = new StringBuilder();
int maxDocs = Math.min(3, similarDocuments.size());
for (int i = 0; i < maxDocs; i++) {
context.append("Context ").append(i + 1).append(": ").append(similarDocuments.get(i).getText())
.append("\n\n");
}
// Create a prompt that includes this context
String enhancedPrompt = "Based on the following contexts, please answer the user's question.\n\n"
+ context.toString() + "User's question: " + getPrompt(promptQuestion);
// Call OpenAI with the enhanced prompt
OpenAIClient client = new OpenAIClientBuilder().endpoint(endpoint)
.credential(new AzureKeyCredential(azureOpenaiKey)).buildClient();
List<ChatMessage> messages = new ArrayList<>();
messages.add(new ChatMessage(ChatRole.SYSTEM)
.setContent("You are an AI assistant that helps people find information. "
+ "Use the provided contexts to answer questions. If the contexts don't contain "
+ "relevant information, just say you don't know."));
messages.add(new ChatMessage(ChatRole.USER).setContent(enhancedPrompt));
ChatCompletionsOptions options = new ChatCompletionsOptions(messages).setTemperature(0.7).setTopP(0.95)
.setMaxTokens(800);
ChatCompletions completions = client.getChatCompletions(deploymentOrModelId, options);
for (ChatChoice choice : completions.getChoices()) {
responseList.add(choice.getMessage().getContent().trim());
}
return new ResponseEntity<>(responseList, HttpStatus.OK);
} catch (Exception ex) {
ex.printStackTrace();
responseList.add("Exception Occurred");
return new ResponseEntity<>(responseList, HttpStatus.INTERNAL_SERVER_ERROR);
}
}
private String getPrompt(PromptQuestion promptQuestion) {
String input = promptQuestion.getQuestion().trim();
return input;
}
private String getEmbedData(EmbedData embedData) {
String input = embedData.getData().trim();
return input;
}
}
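The prompt assembly inside getEnhancedAnswer is plain string building, so it can be pulled out into a pure function and unit tested without Azure OpenAI or Elasticsearch. A sketch of that refactoring, assuming a hypothetical RagPromptSketch class that takes the retrieved texts instead of EmbeddingDto objects; the output format mirrors the controller code above.

```java
import java.util.List;

// Hypothetical extraction of the prompt-building step from getEnhancedAnswer.
final class RagPromptSketch {

    // Uses at most maxDocs contexts (the controller uses 3), then appends the trimmed question
    static String buildEnhancedPrompt(List<String> contexts, String question, int maxDocs) {
        StringBuilder context = new StringBuilder();
        int n = Math.min(maxDocs, contexts.size());
        for (int i = 0; i < n; i++) {
            context.append("Context ").append(i + 1).append(": ")
                    .append(contexts.get(i)).append("\n\n");
        }
        return "Based on the following contexts, please answer the user's question.\n\n"
                + context + "User's question: " + question.trim();
    }
}
```

When no similar documents are found, the prompt contains only the question, and the system message in getEnhancedAnswer instructs the model to say it doesn't know.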
Download Source Code
Download it - Spring Boot AI + Azure OpenAI + RAG Hello World Example