This Python code integrates a retrieval-augmented generation (RAG) process with the Mistral API for generating responses based on both conversation history and relevant documents. Here's an explanation of the different parts of the code:
1. Environment Setup
- Environment Variables: The .env file is loaded, and the Mistral API key is retrieved from environment variables using os.getenv("MISTRAL_API_KEY").
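os.getenv returns None when the variable is not set, so a missing key would otherwise only show up later as a failed API request. A minimal sketch of a fail-fast check, using only the standard library; the helper name require_api_key and its env parameter are assumptions made for illustration and are not part of the original code:

```python
import os


def require_api_key(name="MISTRAL_API_KEY", env=None):
    """Return the named environment variable, or raise if it is missing or empty.

    `env` is a hypothetical hook that lets callers pass a plain dict in tests;
    by default the real process environment is used.
    """
    env = os.environ if env is None else env
    value = env.get(name)
    if not value:
        raise RuntimeError(
            f"{name} is not set; add it to your .env file or shell environment."
        )
    return value
```

In the notebook this could be called right after load_dotenv(), for example api_key = require_api_key(), so that a configuration problem is reported before any request is sent.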
2. Document Retrieval (retrieve_relevant_docs)
Keyword-Based Search: The function retrieve_relevant_docs() takes a user query and a list of documents, splits the query on whitespace, and checks whether any of the resulting words occurs as a case-insensitive substring of each document. A document that contains at least one of the words is added to the list of relevant documents.
Example: If the query is "What is the debt-to-equity ratio of ABC Corp?", the function searches the available documents for "What", "is", "the", "debt-to-equity", "ratio", "of", "ABC", and "Corp?". Because short, common words such as "is" and "the" count as keywords, this simple approach tends to match most documents.
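Splitting the query on whitespace treats common words such as "is" or "the" as keywords, and substring matching lets "ratio" match inside "operations". One possible refinement, sketched here under assumptions, is to compare whole tokens, drop stop words, and rank documents by how many query terms they share. The stop-word list, the function names tokenize and rank_documents, and the top_k parameter are all hypothetical and not part of the original code:

```python
import re

# Hypothetical stop-word list, chosen only for this illustration.
STOP_WORDS = {"what", "is", "the", "of", "a", "an", "in", "for", "to"}

# Words made of letters/digits, optionally joined by hyphens ("debt-to-equity").
TOKEN_RE = re.compile(r"[a-z0-9]+(?:-[a-z0-9]+)*")


def tokenize(text):
    """Lowercase the text, split it into tokens, and drop stop words."""
    return [t for t in TOKEN_RE.findall(text.lower()) if t not in STOP_WORDS]


def rank_documents(query, documents, top_k=1):
    """Return up to `top_k` documents, ordered by shared query terms."""
    query_terms = set(tokenize(query))
    scored = []
    for doc in documents:
        overlap = len(query_terms & set(tokenize(doc)))
        if overlap:
            scored.append((overlap, doc))
    # sort() is stable, so documents with equal scores keep their input order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in scored[:top_k]]
```

This is still plain keyword matching, not semantic search; it only narrows the context that gets sent to the model.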
3. Mistral API Call with RAG (make_mistral_api_call_with_rag)
Combining Relevant Documents: The retrieved relevant documents are concatenated into a single context string, separated by blank lines.
Augmenting the User Message: The user’s message is combined with the context from the relevant documents to form the augmented message, which gives the model background information for a more contextually aware response.
Conversation History: The augmented message is appended to the conversation history, which is passed to the Mistral API as a list of messages (a chat history). The list is modified in place, so the caller's history grows with every call.
API Call: The function sends a POST request to the Mistral API endpoint (/v1/chat/completions) with the model and the conversation history. Optional settings such as temperature and max_tokens are not set in this code, so the API defaults apply.
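The request body for this step can be assembled and inspected without sending anything. The sketch below shows one way to do that, including the optional temperature and max_tokens fields of the chat completions request. The function name build_chat_payload is hypothetical, and unlike the original function it copies the history instead of appending to it in place, which is a design choice made here for illustration:

```python
def build_chat_payload(history, user_message, relevant_docs,
                       model="open-mistral-7b",
                       temperature=None, max_tokens=None):
    """Build the JSON body for a chat completions request (no network call)."""
    context = "\n\n".join(relevant_docs)
    # Only prepend a context block when something was actually retrieved.
    if context:
        augmented = f"Context: {context}\n\n{user_message}"
    else:
        augmented = user_message
    # Copy the history so the caller's list is left untouched.
    messages = list(history) + [{"role": "user", "content": augmented}]
    payload = {"model": model, "messages": messages}
    # Optional sampling settings are included only when given explicitly.
    if temperature is not None:
        payload["temperature"] = temperature
    if max_tokens is not None:
        payload["max_tokens"] = max_tokens
    return payload
```

The returned dictionary could then be passed as the json argument of requests.post, as the original function does with its data dictionary.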
4. Formatting Markdown (format_markdown)
Markdown to HTML: This function converts Markdown-style text into HTML for display purposes. It handles basic Markdown formatting, such as bold (**text**), italic (*text*), headings (#, ##, ###), and code blocks (```python).
Line Breaks: It also replaces newlines (\n) with HTML line breaks (<br>) to ensure proper rendering.
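The order of the substitutions matters in a regex-based converter like this: heading and code-block patterns depend on line boundaries, so they need to run before newlines are turned into <br>, and bold needs to run before italic. The following is a sketch of one possible ordering, not a full Markdown parser. The name markdown_to_html, the HTML escaping step, and the placeholder trick used to protect code blocks are assumptions added for illustration:

```python
import html
import re

FENCE = "`" * 3  # three backticks, the Markdown code-fence marker
CODE_BLOCK_RE = re.compile(FENCE + r"[a-zA-Z]*\n(.*?)\n" + FENCE, re.DOTALL)


def markdown_to_html(text):
    """Convert a small Markdown subset to HTML."""
    # Escape raw HTML first so model output cannot inject markup.
    text = html.escape(text, quote=False)
    code_blocks = []

    def stash_code(match):
        # Store the block and leave a NUL-delimited placeholder behind.
        code_blocks.append("<pre><code>" + match.group(1) + "</code></pre>")
        return "\x00" + str(len(code_blocks) - 1) + "\x00"

    # 1. Pull code blocks out so later rules cannot alter their contents.
    text = CODE_BLOCK_RE.sub(stash_code, text)
    # 2. Headings, longest marker first, anchored to the start of a line.
    for level in (3, 2, 1):
        pattern = r"^" + "#" * level + r" (.+)$"
        text = re.sub(pattern, rf"<h{level}>\1</h{level}>", text, flags=re.MULTILINE)
    # 3. Bold before italic, so "**" is not read as two single asterisks.
    text = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", text)
    text = re.sub(r"\*(.+?)\*", r"<em>\1</em>", text)
    # 4. Remaining newlines become line breaks.
    text = text.replace("\n", "<br>")
    # 5. Put the untouched code blocks back.
    return re.sub(r"\x00(\d+)\x00", lambda m: code_blocks[int(m.group(1))], text)
```

For anything beyond quick notebook display, a dedicated Markdown library is likely the more robust choice.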
5. Displaying the API Response (display_api_response)
Error Handling: If the API response contains an error, it prints the error message and stops further processing.
Extracting Key Data: It extracts information from the API response, including:
- Response Content: The assistant’s generated response.
- Token Usage: The number of tokens used in the prompt, completion, and overall.
- Model Info: The model used for the response.
- Metadata: Additional details like the response ID, object type, creation timestamp, and finish reason.
HTML Display: It formats the response as HTML and displays it using the Jupyter IPython.display.HTML function.
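The extraction step can be separated from the HTML rendering, which makes it easy to check against a sample response without Jupyter. A minimal sketch, assuming the response shape used in this notebook (choices, message, usage, model, id); the function name summarize_response and the layout of the returned dictionary are hypothetical:

```python
def summarize_response(response):
    """Pull the fields of interest out of a chat completions response dict."""
    if "error" in response:
        error = response["error"]
        # Tolerate both a dict with a "message" key and a plain string.
        if isinstance(error, dict):
            message = error.get("message", str(error))
        else:
            message = str(error)
        return {"ok": False, "error": message}

    # Fall back to an empty choice if the list is missing or empty.
    choices = response.get("choices") or [{}]
    choice = choices[0]
    message = choice.get("message", {})
    usage = response.get("usage", {})
    return {
        "ok": True,
        "model": response.get("model"),
        "id": response.get("id"),
        "role": message.get("role"),
        "content": message.get("content"),
        "finish_reason": choice.get("finish_reason"),
        "prompt_tokens": usage.get("prompt_tokens"),
        "completion_tokens": usage.get("completion_tokens"),
        "total_tokens": usage.get("total_tokens"),
    }
```

A display function could then build its HTML from this dictionary instead of reading the raw response in several places.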
6. Example Workflow
- User Query: The user asks, "What is the debt-to-equity ratio of ABC Corp?".
- Documents: The documents list contains financial details about ABC Corp., including information about its debt-to-equity ratio.
- Process:
  - Retrieve Relevant Documents: The function collects every document that shares a word with the query. This includes the document that mentions the debt-to-equity ratio ("ABC Corp. has a current debt-to-equity ratio of 0.3") and, because every example document contains "ABC", the other four as well.
  - Augment the Query: The retrieved documents are added to the query as context.
  - Send API Request: The augmented query is sent to the Mistral API.
  - Display the Response: The API response, including the assistant's generated message, token usage, and metadata, is displayed.
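The retrieve, augment, send cycle above can be exercised offline by swapping the HTTP call for a stand-in. The sketch below does that and also records the assistant's reply in the history, which a multi-turn conversation would need. The names run_rag_turn and fake_send are hypothetical, and fake_send only imitates the shape of a chat completions response; it is not the Mistral API:

```python
def run_rag_turn(history, query, documents, send):
    """Run one retrieve -> augment -> send -> record cycle.

    `send` is any callable that takes the message list and returns a
    response dict shaped like a chat completions response.
    """
    # Same simple keyword check as in this notebook: any query word, any case.
    words = query.lower().split()
    relevant = [doc for doc in documents if any(w in doc.lower() for w in words)]
    context = "\n\n".join(relevant)
    history.append({"role": "user", "content": f"Context: {context}\n\n{query}"})
    response = send(history)
    reply = response["choices"][0]["message"]
    # Keep the assistant's answer so the next turn sees the full exchange.
    history.append({"role": reply["role"], "content": reply["content"]})
    return reply["content"]


def fake_send(messages):
    """Stand-in for the HTTP request; reports how many messages it received."""
    text = f"Received {len(messages)} message(s)."
    return {"choices": [{"message": {"role": "assistant", "content": text}}]}


history = []
documents = ["ABC Corp. has a current debt-to-equity ratio of 0.3."]
answer = run_rag_turn(history, "debt-to-equity ratio", documents, fake_send)
```

Passing the sender in as an argument keeps the workflow logic testable; in the notebook, the real request function would take the place of fake_send.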
Summary of Workflow:
- User Input: The user provides a query.
- Document Retrieval: Relevant documents are retrieved based on keywords from the query.
- Augment Query: The user's query is augmented with relevant document context.
- API Call: The conversation history, including the augmented query, is sent to Mistral for response generation.
- Response Display: The API response is formatted and displayed as HTML, showing both the content and key metadata.
This implementation enhances the assistant's ability to respond based on both the conversation and external documents, providing a more informed answer.
import os
import requests
from dotenv import load_dotenv
from IPython.display import display, HTML
import re
# Load environment variables from the .env file
load_dotenv()
# Get the API key from environment variables
api_key = os.getenv("MISTRAL_API_KEY")

def retrieve_relevant_docs(query, documents):
    """
    Simple keyword-based function to retrieve relevant documents.
    """
    relevant_docs = []
    for doc in documents:
        # Keep the document if any whitespace-separated query word occurs in it
        if any(keyword.lower() in doc.lower() for keyword in query.split()):
            relevant_docs.append(doc)
    return relevant_docs

def make_mistral_api_call_with_rag(conversation_history, current_message, documents, model=None):
    """
    Makes a call to the Mistral API using the provided conversation history
    and current message with retrieved relevant documents.

    :param conversation_history: List of conversation history messages
    :param current_message: Current user message
    :param documents: List of available documents for retrieval
    :param model: Mistral model to use (default is "open-mistral-7b")
    :return: JSON response from the Mistral API
    """
    # Use "open-mistral-7b" as the default model
    if model is None:
        model = "open-mistral-7b"
    # Retrieve relevant documents based on the current message
    relevant_docs = retrieve_relevant_docs(current_message, documents)
    # Combine the retrieved documents into a single context
    context = "\n\n".join(relevant_docs)
    # Add the retrieved context to the current message
    augmented_message = f"Context: {context}\n\n{current_message}"
    # Add the augmented current message to the conversation history
    conversation_history.append({"role": "user", "content": augmented_message})
    url = 'https://api.mistral.ai/v1/chat/completions'
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {api_key}'
    }
    data = {
        "model": model,
        "messages": conversation_history
    }
    response = requests.post(url, headers=headers, json=data)
    return response.json()

def format_markdown(content):
    """
    Converts Markdown content to HTML.
    :param content: Markdown text
    :return: HTML text
    """
    # Remove unnecessary line breaks after enumerations
    content = re.sub(r'(\d+\..*?)\n\n', r'\1\n', content)
    # Convert code blocks and headings first, while the newlines they rely on still exist
    content = re.sub(r'```python\n(.*?)\n```', r'<pre><code>\1</code></pre>', content, flags=re.DOTALL)
    content = re.sub(r'^### (.*)$', r'<h3>\1</h3>', content, flags=re.MULTILINE)
    content = re.sub(r'^## (.*)$', r'<h2>\1</h2>', content, flags=re.MULTILINE)
    content = re.sub(r'^# (.*)$', r'<h1>\1</h1>', content, flags=re.MULTILINE)
    # Bold must be handled before italic so that ** is not read as two single asterisks
    content = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', content)
    content = re.sub(r'\*(.*?)\*', r'<em>\1</em>', content)
    # Replace the remaining newlines with HTML line breaks
    content = content.replace('\n', '<br>')
    return content

def display_api_response(response):
    """
    Formats the JSON response from the Mistral API for HTML display.
    :param response: JSON response from the Mistral API
    :return: None
    """
    # Check if the response contains errors
    if 'error' in response:
        error = response['error']
        # The error may be a dict with a 'message' key or a plain string
        message = error.get('message', error) if isinstance(error, dict) else error
        print(f"Error: {message}")
        return
    # Extract data from the response (fall back to an empty choice if none was returned)
    choice = (response.get('choices') or [{}])[0]
    message = choice.get('message', {})
    role = message.get('role', 'N/A')
    content = message.get('content', 'No content available.')
    usage = response.get('usage', {})
    # Format the content with Markdown
    formatted_content = format_markdown(content)
    html = """
    <div class="api-response">
    """
    # Model Information
    html += f"""
    <div class="bubble">
    <h3>Model Information</h3>
    <p><strong>Model:</strong> {response.get('model', 'N/A')}</p>
    </div>
    """
    # Token Usage
    html += f"""
    <div class="bubble">
    <h3>Token Usage</h3>
    <p><strong>Prompt Tokens:</strong> {usage.get('prompt_tokens', 'N/A')}</p>
    <p><strong>Completion Tokens:</strong> {usage.get('completion_tokens', 'N/A')}</p>
    <p><strong>Total Tokens:</strong> {usage.get('total_tokens', 'N/A')}</p>
    </div>
    """
    # Response Content
    html += f"""
    <div class="bubble">
    <h3>Response Content</h3>
    <p><strong>Role:</strong> {role}</p>
    <p><strong>Content:</strong></p>
    <div>{formatted_content}</div>
    </div>
    """
    # Additional Metadata
    html += f"""
    <div class="bubble">
    <h3>Additional Metadata</h3>
    <p><strong>ID:</strong> {response.get('id', 'N/A')}</p>
    <p><strong>Object:</strong> {response.get('object', 'N/A')}</p>
    <p><strong>Created:</strong> {response.get('created', 'N/A')}</p>
    <p><strong>Finish Reason:</strong> {choice.get('finish_reason', 'N/A')}</p>
    </div>
    """
    html += "</div>"
    display(HTML(html))

conversation_history = []
current_message = "What is the debt-to-equity ratio of ABC Corp?"
documents = [
    "ABC Corp. reported a revenue of 50 million for Q2 2024, a 10 percent increase from Q1 2024. The company's net income for the quarter was 5 million, reflecting a 5 percent profit margin.",
    "ABC Corp. has a current debt-to-equity ratio of 0.3, indicating that the company has a low level of debt compared to its equity.",
    "The market capitalization of ABC Corp. is currently 300 million, based on a share price of 30 and 10 million shares outstanding.",
    "In Q2 2024, ABC Corp. announced a dividend of 0.50 per share, which will be distributed to shareholders on October 1, 2024.",
    "ABC Corp.'s gross profit margin for Q2 2024 was 40 percent, reflecting strong control over cost of goods sold and efficient operations."
]
# Ask the same question to several models. Each call receives a copy of the
# starting history, because the function appends the augmented message in place
# and a shared list would accumulate one duplicate question per model.
models = [
    "open-mistral-7b",
    "open-mixtral-8x7b",
    "open-mixtral-8x22b",
    "mistral-small-latest",
    "mistral-medium-latest",
]
for model in models:
    response = make_mistral_api_call_with_rag(list(conversation_history), current_message, documents, model=model)
    display_api_response(response)