
Semantic Cache To Improve RAG System

Learn the steps involved in creating a semantic cache to improve a RAG system, along with some semantic cache implementation tips.

Introduction

In this blog, we discuss Retrieval Augmented Generation (RAG) and build a typical RAG solution using open-source large language models, FAISS (a vector similarity search library), and ChromaDB (a vector database).

The key difference in this RAG system is an integrated cache that stores user queries and decides whether the context for the prompt comes from the vector database or from the cache.


Understanding Semantic Cache

A semantic cache identifies similar or identical queries and returns their stored responses. Unlike a traditional cache, which relies on exact matching, a semantic cache recognizes the meaning of a query regardless of its wording or structure. For instance, “What is the capital of India?” and “Tell me the name of the capital of India?” convey the same intent, so a semantic cache should treat them as the same query.
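To make the difference from exact-match caching concrete, here is a minimal, self-contained sketch. It is illustrative only: the bag-of-words `embed` function is a crude, hypothetical stand-in for a real sentence encoder (such as the all-mpnet-base-v2 model used later), and the 0.6 similarity threshold is an assumption chosen for this toy example.

```python
import math
import re
from collections import Counter


def embed(text):
    """Toy stand-in for a sentence encoder: a bag-of-words count vector.
    A real semantic cache would use a neural encoder instead."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a, b):
    dot = sum(a[w] * b[w] for w in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


class ToySemanticCache:
    """Hypothetical cache that matches on similarity instead of exact text."""

    def __init__(self, threshold=0.6):
        self.threshold = threshold  # assumed similarity cut-off
        self.entries = []  # list of (embedding, answer) pairs

    def get(self, question):
        query = embed(question)
        best = max(self.entries, key=lambda e: cosine(query, e[0]), default=None)
        if best is not None and cosine(query, best[0]) >= self.threshold:
            return best[1]
        return None  # cache miss

    def put(self, question, answer):
        self.entries.append((embed(question), answer))


exact_cache = {"What is the capital of India?": "New Delhi"}
semantic = ToySemanticCache()
semantic.put("What is the capital of India?", "New Delhi")

paraphrase = "Tell me the name of the capital of India?"
print(exact_cache.get(paraphrase))  # None: exact matching misses the paraphrase
print(semantic.get(paraphrase))     # New Delhi: similarity of about 0.68 clears 0.6
```

Because the toy vectors only count shared words, this captures lexical overlap rather than true meaning; a neural encoder is what lets real semantic caches match paraphrases with little word overlap.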

Implementing Semantic Cache in the RAG System

In a RAG system, relevant context is retrieved from a vector database such as ChromaDB and combined with the user's query in the prompt sent to the large language model. To implement a semantic cache, we add a cache layer between the user's query and the retrieval step. This lets us skip the database lookup for queries similar to ones we have already answered.
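The control flow of this cache layer can be sketched as follows. The function `answer_with_cache` and the callables passed to it are hypothetical placeholders: in this article they would be the FAISS cache, the ChromaDB query, and the LLM. To keep the sketch runnable without models, the cache is a plain dict keyed on normalized text rather than on embeddings.

```python
def answer_with_cache(question, cache_get, cache_put, retrieve, generate):
    """Hypothetical RAG flow with a cache layer in front of retrieval."""
    context = cache_get(question)
    if context is None:
        # Cache miss: fall back to the vector database and remember the result
        context = retrieve(question)
        cache_put(question, context)
    return generate(context, question)


# Minimal stand-ins so the sketch runs. A real semantic cache compares
# embeddings; this dict only matches after lower-casing and trimming punctuation.
store = {}
retrieval_calls = []


def normalize(q):
    return q.lower().strip(" ?.!")


def cache_get(q):
    return store.get(normalize(q))


def cache_put(q, context):
    store[normalize(q)] = context


def retrieve(q):
    retrieval_calls.append(q)
    return f"<documents about: {q}>"


def generate(context, q):
    return f"Answer to '{q}' using {context}"


for q in ["How do vaccines work?", "how do vaccines work"]:
    print(answer_with_cache(q, cache_get, cache_put, retrieve, generate))

print(len(retrieval_calls))  # 1: the second question was served from the cache
```

Only the first question reaches the (simulated) database; the second reuses the stored context, which is exactly the retrieval step a semantic cache saves.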

Building the semantic cache involves the following steps, along with some implementation tips.

Step 1: Install and import all dependencies

				
```python
# Notebook cell: install the dependencies (drop the "!" in a regular shell)
!pip install transformers
!pip install accelerate
!pip install sentence-transformers
!pip install xformers
!pip install chromadb
!pip install datasets
!pip install faiss-cpu
```
				
			

Import all the required libraries

				
```python
import json
import time

import chromadb
import faiss
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from torch import cuda
from transformers import AutoTokenizer, AutoModelForCausalLM
```
				
			

Step 2: Load a dataset suited to your use case. In my case, I used the following dataset:

				
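```python
# Load the MedQuAD medical Q&A dataset from Hugging Face and convert it to pandas
data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split="train")
data = data.to_pandas()
data["id"] = data.index
data.head(10)  # preview the first rows (displayed in a notebook)

MAX_ROWS = 15000     # number of rows to index in ChromaDB
DOCUMENT = "Answer"  # column stored as the document text
TOPIC = "qtype"      # column stored as metadata (question type)
subset_data = data.head(MAX_ROWS)
```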
				
			

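Step 3: Create the ChromaDB collection. ChromaDB stores data in collections; if a collection with the same name already exists, we delete it and then create a new one with the Chroma client's create_collection function. This step uses the chroma_client configured in Step 4, so run that cell first.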

				
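```python
collection_name = "news_collection"
# Delete any existing collection with this name so each run starts fresh.
# Older chromadb versions return Collection objects here, newer ones return names.
existing = [c if isinstance(c, str) else c.name for c in chroma_client.list_collections()]
if collection_name in existing:
    chroma_client.delete_collection(name=collection_name)
collection = chroma_client.create_collection(name=collection_name)
```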
				
			

Next, we add data to the vector database with the collection.add function. The documents are the Answer column of the dataset, each document's metadata records its question type (qtype), and ids supplies a unique identifier for every entry.

				
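```python
# documents: answer text; metadatas: question type; ids: one unique id per row
collection.add(
    documents=subset_data[DOCUMENT].tolist(),
    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
    ids=[f"id{x}" for x in range(len(subset_data))],
)
```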
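The three arguments of collection.add are parallel lists: position i in each one describes the same row. The dependency-free sketch below illustrates that mapping with a hypothetical helper, `build_chroma_payload`, and two placeholder rows that reuse the dataset's Answer and qtype column names.

```python
def build_chroma_payload(rows, document_key="Answer", topic_key="qtype"):
    """Hypothetical helper: turn rows into the parallel lists collection.add expects.

    Position i in documents, metadatas and ids describes the same row, so the
    lists must have equal length and every id must be unique.
    """
    documents = [row[document_key] for row in rows]
    metadatas = [{topic_key: row[topic_key]} for row in rows]
    ids = [f"id{i}" for i in range(len(rows))]
    return {"documents": documents, "metadatas": metadatas, "ids": ids}


# Placeholder rows using the same column names as the dataset above
rows = [
    {"qtype": "information", "Answer": "first placeholder answer"},
    {"qtype": "treatment", "Answer": "second placeholder answer"},
]
payload = build_chroma_payload(rows)
print(payload["ids"])        # ['id0', 'id1']
print(payload["metadatas"])  # [{'qtype': 'information'}, {'qtype': 'treatment'}]
# With a real collection, the payload could be passed as collection.add(**payload)
```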
				
			

Now define a function for querying the ChromaDB database.

				
```python
def query_database(query_text, n_results=10):
    """Return the n_results documents closest to each query text."""
    results = collection.query(query_texts=query_text, n_results=n_results)
    return results
```
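Chroma's collection.query returns a dictionary of nested lists: the outer list has one entry per query text and the inner list holds the matches, closest first. That is why the cache class later reads `answer["documents"][0][0]`. The sketch below uses a hypothetical helper, `top_document`, and a hand-written dictionary shaped like a query result; the values are placeholders, not real output.

```python
def top_document(results, query_index=0):
    """Return the best-ranked document for one query text, or None."""
    documents = results["documents"][query_index]  # matches for that query
    return documents[0] if documents else None


# Hand-written dict shaped like a collection.query(..., n_results=2) result
# for a single query text; the values are placeholders.
fake_results = {
    "ids": [["id12", "id40"]],
    "documents": [["closest answer text", "second answer text"]],
    "metadatas": [[{"qtype": "information"}, {"qtype": "causes"}]],
    "distances": [[0.21, 0.47]],
}
print(top_document(fake_results))  # closest answer text
```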
				
			

Step 4: Configure the ChromaDB client and specify the path where the vector database will be stored. The collection code in Step 3 uses this client, so create it before creating the collection.

				
```python
import chromadb

# Persist the vector database in the local "db" directory
chroma_client = chromadb.PersistentClient(path="db")
```
				
			

Step 5: Create a semantic cache to enhance RAG system performance.

For the semantic cache, we use the FAISS library to store embeddings in memory. Like ChromaDB, FAISS performs similarity search over vectors, but it does not persist data on its own. We create a class with its own encoder that provides the functions users need to run queries.

The class first queries the FAISS cache, which holds the embeddings of previously asked questions. If the closest cached question is within the specified distance threshold, it returns the cached response; otherwise, it retrieves the result from the vector database (ChromaDB), adds it to the cache, and saves the cache to a .json file.
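It helps to know what the distance threshold means. IndexFlatL2 reports squared Euclidean distances, and for unit-length vectors the squared distance and cosine similarity are linked by ||a − b||² = 2 − 2·cos(a, b). The sketch below checks this identity on random normalized vectors; it assumes the encoder returns unit-length embeddings (the all-mpnet-base-v2 model normalizes its output, but verify this for other encoders). Under that assumption, the threshold of 0.35 used in the class corresponds to a cosine similarity of about 0.825.

```python
import numpy as np


def squared_l2_to_cosine(d2):
    """For unit-length vectors, ||a - b||^2 = 2 - 2*cos(a, b)."""
    return 1.0 - d2 / 2.0


rng = np.random.default_rng(0)
a, b = rng.normal(size=(2, 768)).astype("float32")
a /= np.linalg.norm(a)  # normalise, as the encoder is assumed to do
b /= np.linalg.norm(b)

d2 = float(np.sum((a - b) ** 2))  # the kind of value IndexFlatL2 reports
print(np.isclose(squared_l2_to_cosine(d2), float(a @ b), atol=1e-5))  # True
print(squared_l2_to_cosine(0.35))  # about 0.825: cosine implied by threshold=0.35
```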

Initialize the FAISS index and the encoder with the init_cache function:

				
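```python
def init_cache():
    # Exact (brute-force) index over 768-dimensional vectors using L2 distance;
    # a flat index needs no training, so is_trained is always True
    index = faiss.IndexFlatL2(768)
    if index.is_trained:
        print("Index trained")
    # Sentence Transformer encoder; all-mpnet-base-v2 produces 768-dimensional embeddings
    encoder = SentenceTransformer("all-mpnet-base-v2")
    return index, encoder
```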
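For intuition, IndexFlatL2 performs exhaustive search: it compares the query with every stored vector and returns the smallest squared L2 distances, which is also why it needs no training. The NumPy function below, `flat_l2_search`, is a hypothetical sketch of that computation, not FAISS's implementation. Like FAISS, it reports label -1 when fewer than k vectors are stored; this sketch sets the matching distance to infinity.

```python
import numpy as np


def flat_l2_search(stored, queries, k=1):
    """Exhaustive nearest-neighbour search with squared L2 distance.

    stored: (n, d) array of cached embeddings; queries: (m, d) array.
    Returns (D, I), both shaped (m, k); empty slots get inf and -1.
    """
    m = queries.shape[0]
    D = np.full((m, k), np.inf, dtype="float32")
    I = np.full((m, k), -1, dtype="int64")
    if stored.shape[0] == 0:
        return D, I
    # Pairwise squared distances between every query and every stored vector: (m, n)
    dists = ((queries[:, None, :] - stored[None, :, :]) ** 2).sum(axis=-1)
    order = np.argsort(dists, axis=1)[:, :k]  # closest first
    kk = order.shape[1]
    D[:, :kk] = np.take_along_axis(dists, order, axis=1)
    I[:, :kk] = order
    return D, I


stored = np.array([[0.0, 0.0], [3.0, 4.0]], dtype="float32")
D, I = flat_l2_search(stored, np.array([[3.0, 3.0]], dtype="float32"), k=1)
print(D, I)  # nearest is row 1 at squared distance 1.0
```

Exhaustive search is exact but scales linearly with the number of cached questions; approximate indexes such as HNSW trade some accuracy for speed on large caches.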
				
			

The retrieve_cache function loads a previously saved cache from a .json file on disk, or starts with an empty cache if the file does not exist.

				
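```python
def retrieve_cache(json_file):
    # Load a previously saved cache from disk, or start with an empty one
    try:
        with open(json_file, "r") as file:
            cache = json.load(file)
    except FileNotFoundError:
        cache = {"questions": [], "embeddings": [], "answers": [], "response_text": []}
    return cache


def store_cache(json_file, cache):
    # Save the cache to disk; default=str covers values JSON cannot encode natively
    with open(json_file, "w") as file:
        json.dump(cache, file, default=str)
```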
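The FAISS index lives only in memory, while the JSON file keeps the cache as parallel lists, where row i of every list describes the same question. The sketch below round-trips an illustrative cache through a temporary JSON file and shows how the reloaded embeddings can be stacked into the float32 matrix FAISS expects, so they can be re-added to a fresh index and keep index position i aligned with row i. The cache contents are placeholders.

```python
import json
import os
import tempfile

import numpy as np

# Illustrative cache with one entry; the embedding is a short placeholder vector
example_cache = {
    "questions": ["How do vaccines work?"],
    "embeddings": [[0.1, 0.2, 0.3]],
    "answers": [{"documents": [["placeholder answer"]]}],
    "response_text": ["placeholder answer"],
}

path = os.path.join(tempfile.mkdtemp(), "cache_file.json")
with open(path, "w") as f:
    json.dump(example_cache, f)  # what the store step writes to disk
with open(path) as f:
    reloaded = json.load(f)      # what retrieve_cache reads back

# Row i of every list describes the same cached question
assert len({len(v) for v in reloaded.values()}) == 1
# FAISS expects a float32 (n, d) matrix; adding these rows to a fresh index
# keeps index position i aligned with reloaded["response_text"][i]
matrix = np.asarray(reloaded["embeddings"], dtype="float32")
print(matrix.shape)  # (1, 3)
```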
				
			

Next, we declare the class. It initializes the FAISS index, which compares embeddings by (squared) Euclidean distance rather than cosine similarity, and answers each question either from the cache or from the vector database, depending on the specified threshold.

				
```python
class semantic_cache:
    def __init__(self, json_file="cache_file.json", threshold=0.35):
        self.index, self.encoder = init_cache()
        # Maximum squared L2 distance for a cached answer to count as a hit
        self.euclidean_threshold = threshold
        self.json_file = json_file
        self.cache = retrieve_cache(self.json_file)
        # The FAISS index lives only in memory: re-add embeddings loaded from disk
        # so that index position i matches row i of the cache lists
        if self.cache["embeddings"]:
            self.index.add(np.array(self.cache["embeddings"], dtype="float32"))

    def ask(self, question: str) -> str:
        start_time = time.time()
        try:
            embedding = self.encoder.encode([question])
            # Nearest cached question; on an empty index FAISS returns label -1
            D, I = self.index.search(embedding, 1)
            if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
                row_id = int(I[0][0])
                print("Answer recovered from Cache.")
                print(f"{D[0][0]:.3f} smaller than {self.euclidean_threshold}")
                print(f"Found cache in row: {row_id} with score {D[0][0]:.3f}")
                print("response_text: " + self.cache["response_text"][row_id])
                print(f"Time taken: {time.time() - start_time:.3f} seconds")
                return self.cache["response_text"][row_id]

            # Cache miss: query ChromaDB and store the result for later questions
            answer = query_database([question], 1)
            response_text = answer["documents"][0][0]
            self.cache["questions"].append(question)
            self.cache["embeddings"].append(embedding[0].tolist())
            self.cache["answers"].append(answer)
            self.cache["response_text"].append(response_text)
            print("Answer recovered from ChromaDB.")
            print(f"response_text: {response_text}")
            self.index.add(embedding)
            store_cache(self.json_file, self.cache)
            print(f"Time taken: {time.time() - start_time:.3f} seconds")
            return response_text
        except Exception as e:
            raise RuntimeError(f"Error during 'ask' method: {e}") from e
```
				
			

Now create an instance of the class and call its ask method to get a response from ChromaDB or from the cache.

				
```python
cache = semantic_cache("4cache.json")

results = cache.ask("How do vaccines work?")
# Answer recovered from ChromaDB.

results = cache.ask("Explain briefly what is a Sydenham chorea")
# Answer recovered from ChromaDB.

results = cache.ask("Briefly explain me what is a Sydenham chorea.")
# Answer recovered from Cache.

question_def = "Write in 20 words what is a Sydenham chorea."
results = cache.ask(question_def)
# Answer recovered from Cache.
```
				
			

Loading Model

Now we load our language model from Hugging Face with the transformers library, build the extended prompt, send it to the model, and wait for its response.

				
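```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use the GPU when available, otherwise fall back to the CPU
device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
model_id = "google/gemma-2b-it"  # gated model: accept its licence on Hugging Face and log in first
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)

# Extend the user's question with the context returned by the cache or ChromaDB
prompt_template = f"Relevant context: {results}\n\nThe user's question: {question_def}"
print(prompt_template)
input_ids = tokenizer(prompt_template, return_tensors="pt").to(model.device)
outputs = model.generate(**input_ids, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```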
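The prompt above is built with a plain f-string. If the retrieved context can be long, it may be worth guarding its size before tokenizing. The helper below, `build_prompt`, is a hypothetical sketch that keeps the same prompt layout and truncates the context to an assumed character budget; counting tokens with the tokenizer would be more precise.

```python
def build_prompt(context: str, question: str, max_context_chars: int = 2000) -> str:
    """Hypothetical helper that assembles the extended prompt.

    max_context_chars is an assumed, rough guard against overlong context.
    """
    if len(context) > max_context_chars:
        # Keep the beginning of the context and mark that it was cut
        context = context[:max_context_chars].rstrip() + " [truncated]"
    return f"Relevant context: {context}\n\nThe user's question: {question}"


prompt = build_prompt("retrieved answer text", "Write in 20 words what is a Sydenham chorea.")
print(prompt)
# Relevant context: retrieved answer text
#
# The user's question: Write in 20 words what is a Sydenham chorea.
```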
				
			

Conclusion

Semantic caching is a valuable optimization for RAG systems. It handles recurring queries efficiently and improves the system's overall performance. Serving an answer directly from the cache instead of querying ChromaDB can reduce data retrieval time by around 50%.
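The overall gain also depends on how often questions actually hit the cache. As a rough back-of-the-envelope model (our assumption, not a measurement from this article), every query pays the cache lookup and only the misses also pay the database retrieval. The hypothetical function below computes that average; replace the example timings with your own measurements.

```python
def expected_latency(hit_rate: float, t_cache: float, t_db: float) -> float:
    """Average seconds per query with a semantic cache in front of the database.

    Every query pays the cache lookup (t_cache); the misses, a fraction
    (1 - hit_rate), also pay the vector database retrieval (t_db).
    """
    if not 0.0 <= hit_rate <= 1.0:
        raise ValueError("hit_rate must be between 0 and 1")
    return t_cache + (1.0 - hit_rate) * t_db


# Hypothetical timings: 0.01 s per cache lookup, 0.2 s per ChromaDB query
print(round(expected_latency(0.5, 0.01, 0.2), 3))  # 0.11
print(round(expected_latency(0.0, 0.01, 0.2), 3))  # 0.21: no hits, the cache only adds overhead
```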

By combining semantic understanding with efficient indexing, a semantic cache enhances the scalability, reliability, and user experience of a RAG system, and semantic caching is likely to play a growing role across many domains.

FAQs

A semantic cache enhances a RAG system by storing user queries with their responses and reusing them for similar queries, which reduces database retrieval time. It also improves the consistency of responses.

Implementing a semantic cache involves several steps: initialize the FAISS index and the encoder, store the cache data in a JSON file on disk, retrieve similar queries using Euclidean distance, and integrate the semantic cache into the RAG system.

A semantic cache matters because it lets a RAG system answer similar queries without repeating the database retrieval, which speeds up responses, keeps them consistent, and makes the system more scalable and reliable.

When choosing a semantic cache for a RAG system, consider the indexing method (such as FlatL2 or HNSW), the similarity metric (such as Euclidean distance or cosine similarity), scalability, and storage efficiency.