COHERE - RAG
Simon-Pierre Boucher
2024-09-14
This Python script interacts with the Cohere API to generate responses by incorporating retrieval-augmented generation (RAG). The process uses relevant documents to augment the user query, helping to produce more accurate and contextually aware responses. Here’s a detailed breakdown:
1. Environment Setup
- Load API Key: The script uses load_dotenv() to load environment variables from a .env file, then retrieves the Cohere API key (COHERE_API_KEY) with os.getenv().
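A minimal sketch of the key lookup with a guard for a missing variable. The helper name require_api_key is hypothetical and not part of the script, which calls os.getenv() directly without checking the result. The helper reads from a mapping so it can be tried without a real .env file.

```python
import os


def require_api_key(env=os.environ, name="COHERE_API_KEY"):
    """Return the API key stored under `name`, or raise if it is missing or empty.

    Hypothetical helper: the original script does not perform this check.
    """
    key = env.get(name)
    if not key:
        raise RuntimeError(f"{name} is not set; add it to your .env file")
    return key


# In the notebook, load_dotenv() would run first so that os.environ
# already contains the values from the .env file.
demo_env = {"COHERE_API_KEY": "example-key"}
print(require_api_key(demo_env))
```

Failing early here gives a clearer message than an authentication error returned later by the API.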
2. Document Retrieval (retrieve_relevant_docs)
- Keyword Matching: This function splits the user query on whitespace and treats every resulting token as a keyword. A document is considered relevant, and added to the result list, if any token appears in it as a case-insensitive substring.
- Example: If the query is "What is the debt-to-equity ratio of ABC Corp?", the tokens include "debt-to-equity", "ratio", and "ABC", but also common words such as "is" and "the". Because no stop words are filtered out, a document can match on those alone.
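To make the matching behaviour concrete, the sketch below reproduces the whitespace-token substring test and compares it with a variant that ignores common words. The stop-word list and the names match_any_token and match_without_stop_words are assumptions for illustration, not part of the script.

```python
import string

# Assumed stop-word list for illustration; the original script has none.
STOP_WORDS = {"what", "is", "the", "of", "a", "an", "in"}


def match_any_token(query, documents):
    """Same rule as the script: keep a document if any query token is a substring."""
    tokens = [t.lower() for t in query.split()]
    return [d for d in documents if any(t in d.lower() for t in tokens)]


def match_without_stop_words(query, documents):
    """Variant: strip punctuation from tokens and skip stop words first."""
    tokens = [t.strip(string.punctuation).lower() for t in query.split()]
    tokens = [t for t in tokens if t and t not in STOP_WORDS]
    return [d for d in documents if any(t in d.lower() for t in tokens)]


docs = [
    "The weather in Paris is mild in spring.",
    "ABC Corp. has a debt-to-equity ratio of 0.3.",
]
query = "What is the debt-to-equity ratio of ABC Corp?"

print(len(match_any_token(query, docs)))           # both documents match, the first only on "is" / "the"
print(len(match_without_stop_words(query, docs)))  # only the ABC Corp. document matches
```

Filtering stop words is one simple way to reduce noise. A production system would more likely rank documents by embedding similarity.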
3. Cohere API Call with RAG (make_cohere_api_call_with_rag)
- Retrieving Documents: Based on the user’s query, the function retrieves relevant documents and joins them into a single context string.
- Augmenting the User Query: The context is prepended to the user’s query to form an augmented message, giving the model additional material to ground its response.
- Conversation History: The conversation history is converted into Cohere’s role-based structure (USER for user messages and CHATBOT for responses).
- API Call: The converted history and the augmented query are sent to Cohere’s chat endpoint (/v1/chat) via a POST request. The model is "command-r" by default, but it can be changed through the model parameter.
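The request body described in this section can be assembled without touching the network. The sketch below covers only that step. build_chat_payload is a hypothetical name; the field names (chat_history, message, model, role) are the ones the script itself sends to /v1/chat.

```python
def build_chat_payload(conversation_history, current_message, relevant_docs, model="command-r"):
    """Assemble the JSON body for the chat request (no network call is made)."""
    # Join the retrieved documents and prepend them to the user's question
    context = "\n\n".join(relevant_docs)
    augmented_message = f"Context: {context}\n\n{current_message}"

    # Map the notebook's lowercase roles onto the USER / CHATBOT labels
    chat_history = [
        {
            "role": "USER" if turn["role"] == "user" else "CHATBOT",
            "message": turn["content"],
        }
        for turn in conversation_history
    ]
    return {"chat_history": chat_history, "message": augmented_message, "model": model}


history = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]
payload = build_chat_payload(
    history,
    "What is the debt-to-equity ratio of ABC Corp?",
    ["ABC Corp. has a current debt-to-equity ratio of 0.3."],
)
print(payload["message"])
```

Keeping payload construction separate from the HTTP call makes the augmentation step easy to inspect before any tokens are billed.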
4. Markdown to HTML Conversion (format_markdown)
- Format Conversion: The format_markdown() function uses regular expressions to convert a subset of Markdown (bold, italics, headers, and fenced Python code blocks) into HTML, so that the response displays correctly when rendered in a Jupyter notebook or HTML page.
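The two inline rules are enough to show why rule order matters in a regex-based converter. The sketch below uses the same bold and italic patterns as the script. The apply_rules helper and the sample strings are assumptions added for illustration.

```python
import re

BOLD = (r'\*\*(.*?)\*\*', r'<strong>\1</strong>')
ITALIC = (r'\*(.*?)\*', r'<em>\1</em>')


def apply_rules(text, rules):
    """Apply (pattern, replacement) pairs in the order given."""
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text


sample = "**Ratio:** 0.3 is *low*"
print(apply_rules(sample, [BOLD, ITALIC]))
# <strong>Ratio:</strong> 0.3 is <em>low</em>

# With the order reversed, the italic rule consumes each "**" pair as an
# empty emphasis span and the bold rule never gets a chance to match.
print(apply_rules("**x**", [ITALIC, BOLD]))
# <em></em>x<em></em>
```

The script applies the bold rule first, which matches the first ordering shown here.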
5. Displaying the API Response (display_api_response)
- Error Handling: If the API response contains an error, the error message is printed and nothing else is displayed.
- Response Extraction: The chatbot response text is extracted along with token usage information (input and output tokens) from the billing metadata. The chat history is also read from the response, although it is not displayed.
- HTML Formatting: The response is formatted as styled HTML and displayed using IPython.display.HTML(). The displayed information includes:
  - Token Usage: A breakdown of input and output tokens billed by the API.
  - Response Content: The assistant’s generated message, rendered as HTML.
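The extraction step relies on dict.get() with defaults, so a response with missing fields still renders. A minimal sketch of that pattern follows. summarize_response is a hypothetical name, and the sample response is made up for illustration, using only the keys the script reads (text, meta, billed_units).

```python
def summarize_response(response):
    """Pull out the fields the display step needs, with defaults for missing keys."""
    billed_units = response.get("meta", {}).get("billed_units", {})
    return {
        "text": response.get("text", "No content available."),
        "input_tokens": billed_units.get("input_tokens", "N/A"),
        "output_tokens": billed_units.get("output_tokens", "N/A"),
    }


# Made-up response for illustration; the numbers are not real API output.
sample_response = {
    "text": "The debt-to-equity ratio is 0.3.",
    "meta": {"billed_units": {"input_tokens": 42, "output_tokens": 9}},
}
print(summarize_response(sample_response))
print(summarize_response({}))  # every field falls back to its default
```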
6. Example Workflow
- User Query: The user asks, "What is the debt-to-equity ratio of ABC Corp?".
- Document Retrieval: The retrieve_relevant_docs() function returns every document that shares a token with the query. In this example all five documents match, since each one mentions "ABC"; among them is the one with the answer: "ABC Corp. has a current debt-to-equity ratio of 0.3."
- Augmentation: The retrieved documents are added as context to the query.
- API Request: The augmented query and conversation history are sent to the Cohere API.
- Response: The chatbot generates a response, which is displayed in HTML format.
Summary of Workflow:
- User Input: The user provides a query.
- Document Retrieval: The system retrieves relevant documents based on the query.
- Augmentation: The user's query is combined with relevant document context.
- API Call: The conversation history, along with the augmented query, is sent to Cohere’s chat API.
- Display: The response from the API is formatted and displayed.
This setup allows for more informed and contextually aware responses by integrating external documents with the conversational input, enhancing the capabilities of a typical chatbot.
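The notebook always passes an empty conversation_history. One way to carry context across turns is to append each exchange after it completes, sketched below with a stand-in for the API call. fake_chat and ask are hypothetical names; the role/content message shape and the "text" field follow what the script reads. Storing the original question rather than the augmented one keeps retrieved context from accumulating in the history. That is a design choice of this sketch, not something the script prescribes.

```python
def fake_chat(history, message):
    """Stand-in for the real API call so the sketch runs offline."""
    return {"text": f"(reply to: {message})"}


def ask(history, message, chat_fn=fake_chat):
    """Send one message, then record both sides of the exchange in `history`."""
    response = chat_fn(history, message)
    reply = response.get("text", "")
    # Store the original question, not the context-augmented one
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": reply})
    return reply


conversation_history = []
ask(conversation_history, "What is the debt-to-equity ratio of ABC Corp?")
ask(conversation_history, "Is that considered low?")
print(len(conversation_history))  # 4 entries: two questions and two replies
```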
In [2]:
import os
import requests
from dotenv import load_dotenv
from IPython.display import display, HTML
import re
# Load environment variables from the .env file
load_dotenv()
# Get the API key from environment variables
api_key = os.getenv("COHERE_API_KEY")
def retrieve_relevant_docs(query, documents):
    """
    Simple keyword-based function to retrieve relevant documents.
    """
    relevant_docs = []
    for doc in documents:
        if any(keyword.lower() in doc.lower() for keyword in query.split()):
            relevant_docs.append(doc)
    return relevant_docs
def make_cohere_api_call_with_rag(conversation_history, current_message, documents, model="command-r"):
    """
    Makes an API call to Cohere using the provided conversation history and the current message
    with retrieved relevant documents.
    :param conversation_history: List of messages from the conversation history
    :param current_message: Current user's message
    :param documents: List of available documents for retrieval
    :param model: Cohere model to use (default is "command-r")
    :return: JSON response from the Cohere API
    """
    # Retrieve relevant documents based on the current message
    relevant_docs = retrieve_relevant_docs(current_message, documents)
    # Combine retrieved documents into a single context
    context = "\n\n".join(relevant_docs)
    # Add the retrieved context to the current message
    augmented_message = f"Context: {context}\n\n{current_message}"
    # Convert role format for the Cohere API
    cohere_history = []
    for message in conversation_history:
        cohere_history.append({
            "role": "USER" if message["role"] == "user" else "CHATBOT",
            "message": message["content"]
        })
    # Add the augmented current message as a "user" message
    cohere_history.append({"role": "USER", "message": augmented_message})
    url = 'https://api.cohere.com/v1/chat'
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'accept': 'application/json'
    }
    data = {
        "chat_history": cohere_history[:-1],  # Complete history without the last message
        "message": augmented_message,  # Last message as the current message
        "model": model  # Specify the model to use
    }
    response = requests.post(url, headers=headers, json=data)
    return response.json()
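def format_markdown(content):
    # Set fenced Python code blocks aside so the rules below leave them untouched
    code_blocks = []
    def stash(match):
        code_blocks.append(match.group(1))
        return f"\x00{len(code_blocks) - 1}\x00"
    content = re.sub(r'```python\n(.*?)\n```', stash, content, flags=re.DOTALL)
    # Remove unnecessary line breaks after enumerations
    content = re.sub(r'(\d+\..*?)\n\n', r'\1\n', content)
    # Convert headers while newlines still mark where each line ends
    content = re.sub(r'^### (.*)$', r'<h3>\1</h3>', content, flags=re.MULTILINE)
    content = re.sub(r'^## (.*)$', r'<h2>\1</h2>', content, flags=re.MULTILINE)
    content = re.sub(r'^# (.*)$', r'<h1>\1</h1>', content, flags=re.MULTILINE)
    # Convert bold before italics, then turn the remaining newlines into <br>
    content = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', content)
    content = re.sub(r'\*(.*?)\*', r'<em>\1</em>', content)
    content = content.replace('\n', '<br>')
    # Put the code blocks back, wrapped in <pre><code>
    for i, code in enumerate(code_blocks):
        content = content.replace(f"\x00{i}\x00", f"<pre><code>{code}</code></pre>")
    return content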
def display_api_response(response):
    """
    Formats the JSON response from the Cohere API for HTML display.
    :param response: JSON response from the Cohere API
    :return: None
    """
    # Check if the response contains errors
    if 'error' in response:
        print(f"Error: {response['error']['message']}")
        return
    # Extract data from the response
    text = response.get('text', 'No content available.')
    chat_history = response.get('chat_history', [])
    meta = response.get('meta', {})
    billed_units = meta.get('billed_units', {})
    # Format the content with Markdown
    formatted_content = format_markdown(text)
    html = """
    <style>
        .api-response {
            font-family: Arial, sans-serif;
            margin: 20px;
        }
        .bubble {
            padding: 15px;
            border-radius: 15px;
            margin-bottom: 10px;
        }
    </style>
    <div class="api-response">
    """
    # Token usage information
    html += f"""
    <div class="bubble">
        <h3>Token Usage</h3>
        <p><strong>Input Tokens:</strong> {billed_units.get('input_tokens', 'N/A')}</p>
        <p><strong>Output Tokens:</strong> {billed_units.get('output_tokens', 'N/A')}</p>
    </div>
    """
    # Response content
    html += f"""
    <div class="bubble">
        <h3>Response Content</h3>
        <p><strong>Role:</strong> CHATBOT</p>
        <p><strong>Content:</strong></p>
        <div>{formatted_content}</div>
    </div>
    """
    html += "</div>"
    display(HTML(html))
In [3]:
conversation_history = []
current_message = "What is the debt-to-equity ratio of ABC Corp?"
documents = [
    "ABC Corp. reported a revenue of 50 million for Q2 2024, a 10 percent increase from Q1 2024. The company's net income for the quarter was 5 million, reflecting a 5 percent profit margin.",
    "ABC Corp. has a current debt-to-equity ratio of 0.3, indicating that the company has a low level of debt compared to its equity.",
    "The market capitalization of ABC Corp. is currently 300 million, based on a share price of 30 and 10 million shares outstanding.",
    "In Q2 2024, ABC Corp. announced a dividend of 0.50 per share, which will be distributed to shareholders on October 1, 2024.",
    "ABC Corp.'s gross profit margin for Q2 2024 was 40 percent, reflecting strong control over cost of goods sold and efficient operations."
]
In [4]:
response = make_cohere_api_call_with_rag(conversation_history, current_message, documents, model="command-r")
display_api_response(response)
In [5]:
response = make_cohere_api_call_with_rag(conversation_history, current_message, documents, model="command-r-plus")
display_api_response(response)
In [6]:
response = make_cohere_api_call_with_rag(conversation_history, current_message, documents, model="command-nightly")
display_api_response(response)
In [7]:
response = make_cohere_api_call_with_rag(conversation_history, current_message, documents, model="command-light")
display_api_response(response)