Part 2: Building a Self-Correcting RAG with Conditional Edges
Why Settle for One-Shot? Towards a Cognitively Inspired RAG
My fascination lies in the intricate machinery of the human brain. How do we learn? How do we adapt? Most critically, how do we self-correct? We don't just process information once; we reflect, we re-evaluate, we retry. Traditional RAG pipelines, while powerful, often feel like a single neuron firing – a one-shot process that retrieves, generates, and stops. This isn't how intelligence works.
For me, the goal isn't just to build an AI, but to understand and replicate cognitive functions. Imagine a system that, when faced with an inadequate answer, doesn't just give up, but introspects, refines its approach, and tries again. This iterative, reflective process is crucial for robust AI and aligns perfectly with how specialized "expert" modules in an MoE (Mixture of Experts) architecture might collaborate and self-regulate, much like different regions of the brain.
This post, Part 2 of my exploration, dives into building a RAG pipeline that isn't just a linear chain, but a dynamic, self-correcting graph. We'll move beyond simple sequential steps to incorporate feedback loops, allowing our agent to reflect on its output and reroute its execution based on internal state. My ethos remains: performance, directness, and minimal abstraction. While I'll use a graph library to demonstrate the concept of conditional edges, understand that the underlying mechanics are paramount, and for high-performance production systems, I'd often favor custom, raw API implementations over bloated frameworks.
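To make the "raw API" preference concrete before we dive in: a direct chat-completion call needs nothing beyond httpx and the standard OpenAI Chat Completions endpoint. A minimal sketch, with error handling and retries omitted; the `raw_chat` helper name is mine, not a library API:

```python
import os
import httpx

def raw_chat(prompt: str, model: str = "gpt-4o-mini") -> str:
    """Single direct call to the OpenAI Chat Completions endpoint, no client library."""
    resp = httpx.post(
        "https://api.openai.com/v1/chat/completions",
        headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"},
        json={"model": model, "messages": [{"role": "user", "content": prompt}]},
        timeout=30.0,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]
```

Every LLM call in the pipeline below could be swapped onto a helper like this with no change to the graph logic.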
Architecture: The Self-Correcting Loop
The core idea is simple yet profound: stateful execution driven by conditional logic. Our agent will maintain a shared state, execute specific functions (nodes) based on this state, and then dynamically decide its next step using conditional_edges.
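Stripped of any framework, that control flow is just a loop over a state dictionary. A conceptual sketch, assuming the retrieve, generate, reflect, and rewrite_query functions defined in the next sections, a MAX_ATTEMPTS cap, and a user_query string:

```python
# Conceptual skeleton only -- the graph library below expresses exactly this loop declaratively.
state = {"query": user_query, "attempts": 0}
while state["attempts"] < MAX_ATTEMPTS:
    state = retrieve(state)               # fetch candidate context
    state = generate(state)               # draft an answer from that context
    state = reflect(state)                # critic pass: is the answer good enough?
    if state["reflection"] == "satisfactory":
        break                             # good answer, stop
    if state["reflection"] == "needs_rewrite":
        state = rewrite_query(state)      # reformulate, then loop back to retrieval
    # "needs_retrieval" simply falls through and retries retrieval as-is
```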
1. Defining the Graph State
First, we need to define what information our RAG agent needs to carry through its execution. A TypedDict is perfect for this, offering clarity and type safety without unnecessary overhead.
from typing import List, Dict, TypedDict, Literal, Optional
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
# Minimal LangGraph imports for conceptual clarity
from langgraph.graph import StateGraph, END
# For competitive programmer-style efficiency and raw API preference:
import os
import time
# --- Mock/Setup (replace with your actual integrations) ---
# For LLM calls, Agno (or a custom httpx client) would be my preference for speed.
# Here, langchain_openai is used to keep the example simple, but the intent stands.
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY", "sk-YOUR-OPENAI-API-KEY")
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0) # Fast, cheap for iteration
# Mock Retriever: In reality, this connects to a real vector DB (e.g., Chroma, Qdrant)
# For competitive programming, I'd optimize embedding lookups and vector storage.
class MockRetriever:
def __init__(self, documents: List[str]):
# This would be a full-fledged vector store in production
# For a quick demo, a simple text match or pre-embedded docs would work.
self.documents = documents
print("MockRetriever initialized with a simple text store.")
def get_relevant_documents(self, query: str) -> List[Document]:
# Simulate retrieval time for performance analysis
time.sleep(0.1)
# Simple keyword matching for demo. Real retriever uses vector similarity.
relevant = [
Document(page_content=doc)
for doc in self.documents
if query.lower() in doc.lower()
]
if not relevant:
# Fallback for better demo
relevant = [Document(page_content=self.documents[i]) for i in range(min(2, len(self.documents)))]
print(f"Retrieved {len(relevant)} docs for query: '{query}'")
return relevant
mock_docs = [
"The quick brown fox jumps over the lazy dog.",
"Python is a high-level, interpreted programming language, valued for its readability.",
"LangGraph allows building stateful, multi-actor applications with conditional logic.",
"Conditional edges in LangGraph enable dynamic routing based on state changes.",
"The human brain is an incredibly complex organ, capable of reflection and self-correction.",
"Mixture of Experts (MoE) architectures can mimic specialized brain regions, improving scalability.",
"Competitive programming emphasizes efficient algorithms and data structures for optimal performance.",
"Agno is a lightweight Python client for LLMs, prioritizing raw performance and minimal overhead.",
"Neural networks learn through backpropagation, adjusting weights to minimize error.",
"RAG pipelines combine retrieval with generation for grounded, accurate answers, reducing hallucinations."
]
retriever = MockRetriever(mock_docs)
class GraphState(TypedDict):
"""
Represents the state of our graph.
- `query`: The initial or current query.
- `documents`: List of retrieved documents.
- `answer`: The generated answer.
- `reflection`: The result of the reflection step (e.g., "satisfactory", "needs_retrieval", "fail").
- `attempts`: Counter for retry attempts.
- `error`: Any error message encountered.
"""
query: str
documents: List[Document]
answer: Optional[str]
reflection: Optional[Literal["satisfactory", "needs_retrieval", "needs_rewrite", "fail"]]
attempts: int
error: Optional[str]
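One note before the nodes: the MockRetriever above only does keyword matching. A production retriever would rank by vector similarity, as the comments hint. Here is a minimal in-memory sketch of that idea, assuming OpenAIEmbeddings and numpy, and reusing the `List` and `Document` imports from above; the `SimpleVectorRetriever` class and its parameters are illustrative, not a specific vector-store API:

```python
import numpy as np
from langchain_openai import OpenAIEmbeddings

class SimpleVectorRetriever:
    """Illustrative retriever: cosine similarity over pre-computed embeddings, held in memory."""
    def __init__(self, documents: List[str], k: int = 3):
        self.documents = documents
        self.k = k
        self.embedder = OpenAIEmbeddings(model="text-embedding-3-small")
        # Embed once up front; in production these vectors live in Chroma, Qdrant, etc.
        self.vectors = np.array(self.embedder.embed_documents(documents))

    def get_relevant_documents(self, query: str) -> List[Document]:
        q = np.array(self.embedder.embed_query(query))
        # Cosine similarity of the query against every stored document vector
        sims = (self.vectors @ q) / (
            np.linalg.norm(self.vectors, axis=1) * np.linalg.norm(q) + 1e-9
        )
        top = np.argsort(sims)[::-1][: self.k]
        return [Document(page_content=self.documents[i]) for i in top]
```

Swapping `retriever = SimpleVectorRetriever(mock_docs)` in for the mock would leave the rest of the pipeline untouched, since the nodes only rely on `get_relevant_documents`.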
2. Defining the Nodes (The "Experts")
Each node in our graph performs a specific task, taking the current GraphState and returning an updated version. These are our "experts" – specialized functions that contribute to the overall cognitive process.
# Node 1: Retrieval
def retrieve(state: GraphState) -> GraphState:
"""
Retrieves documents based on the current query.
"""
start_time = time.perf_counter()
print(f"---NODE: RETRIEVE (Attempt {state['attempts'] + 1})---")
query = state["query"]
documents = retriever.get_relevant_documents(query)
end_time = time.perf_counter()
print(f"Retrieval took {end_time - start_time:.4f} seconds.")
return {**state, "documents": documents, "attempts": state["attempts"] + 1}
# Node 2: Generation
def generate(state: GraphState) -> GraphState:
"""
Generates an answer using the retrieved documents and the query.
"""
start_time = time.perf_counter()
print("---NODE: GENERATE---")
query = state["query"]
documents = state["documents"]
if not documents:
print("No documents found for generation. Setting reflection to 'fail'.")
return {**state, "answer": None, "reflection": "fail"}
# Use a lean prompt for directness and performance
prompt_template = ChatPromptTemplate.from_messages(
[
("system",
"You are a helpful assistant. Provide a concise answer based *only* on the following context. "
"If the answer is not in the context, state that clearly and do not make up information.\n\n"
"Context:\n{context}"),
("user", "{query}")
]
)
rag_chain = prompt_template | llm | StrOutputParser()
context = "\n".join([doc.page_content for doc in documents])
answer = rag_chain.invoke({"context": context, "query": query})
end_time = time.perf_counter()
print(f"Generation took {end_time - start_time:.4f} seconds.")
print(f"Generated Answer: {answer}")
return {**state, "answer": answer}
# Node 3: Reflection (The "Self-Correction" Brain)
def reflect(state: GraphState) -> GraphState:
"""
Reflects on the generated answer and determines if it's satisfactory
or if further action (e.g., re-retrieval, query rewrite) is needed.
"""
start_time = time.perf_counter()
print("---NODE: REFLECT---")
query = state["query"]
documents = state["documents"]
answer = state["answer"]
attempts = state["attempts"]
if answer is None: # This should ideally be caught earlier, but good to have
print("No answer to reflect upon. Marking as 'fail'.")
return {**state, "reflection": "fail"}
# We use the LLM to act as our 'critic' or 'internal monitor'.
# This prompt is critical for the quality of self-correction.
reflection_prompt = ChatPromptTemplate.from_messages(
[
("system",
"You are an expert critic. Your task is to evaluate the provided answer based on the original query and context. "
"Be strict and objective. Your output should be ONE of the following keywords:\n"
"- 'satisfactory': The answer is directly relevant, comprehensive given the context, and addresses the query.\n"
"- 'needs_retrieval': The answer is partially relevant but seems to lack sufficient detail or context. More retrieval is needed.\n"
"- 'needs_rewrite': The answer is completely off-topic or misunderstands the query. The original query might need reformulation.\n"
"- 'fail': The answer explicitly states it cannot find information, or is completely unhelpful/hallucinatory.\n\n"
"Original Query: {query}\n"
"Retrieved Context:\n{context}\n"
"Generated Answer: {answer}\n\n"
"Evaluation (single keyword):"
),
("user", "Evaluate the above.")
]
)
reflection_chain = reflection_prompt | llm | StrOutputParser()
context_str = "\n".join([doc.page_content for doc in documents])
reflection_result = reflection_chain.invoke({
"query": query,
"context": context_str,
"answer": answer
}).strip().lower()
# Normalize reflection result to our Literal type
if "satisfactory" in reflection_result:
reflection_type: Literal["satisfactory", "needs_retrieval", "needs_rewrite", "fail"] = "satisfactory"
elif "needs_retrieval" in reflection_result:
reflection_type = "needs_retrieval"
elif "needs_rewrite" in reflection_result:
reflection_type = "needs_rewrite"
else:
reflection_type = "fail" # Default to fail if LLM gives garbage
print(f"Reflection result: {reflection_type}")
end_time = time.perf_counter()
print(f"Reflection took {end_time - start_time:.4f} seconds.")
return {**state, "reflection": reflection_type}
# Node 4: Query Rewriting (for iterative improvement)
def rewrite_query(state: GraphState) -> GraphState:
"""
Rewrites the query based on reflection feedback, aiming for better retrieval.
"""
start_time = time.perf_counter()
print("---NODE: REWRITE QUERY---")
original_query = state["query"]
reflection_result = state["reflection"]
answer = state["answer"]
if reflection_result == "needs_rewrite":
# Prompt the LLM to reformulate the query
rewrite_prompt = ChatPromptTemplate.from_messages(
[
("system",
"The previous retrieval and generation failed to provide a satisfactory answer for the query '{original_query}'. "
"The current answer was: '{answer}'. "
"Based on this, please reformulate the original query to be more specific or to explore different angles, "
"aiming for better retrieval. Provide *only* the new query."
),
("user", "Reformulate the query.")
]
)
new_query = (rewrite_prompt | llm | StrOutputParser()).invoke({
"original_query": original_query,
"answer": answer
})
print(f"Original Query: '{original_query}' -> Rewritten Query: '{new_query}'")
return {**state, "query": new_query, "answer": None, "documents": []} # Reset documents and answer for fresh cycle
    else:
        # Defensive fallback: with the routing defined below, this node is only reached on
        # 'needs_rewrite', but if we ever land here, keep the original query and reset state.
        print(f"No query rewrite needed. Original query '{original_query}' will be used for re-retrieval.")
        return {**state, "answer": None, "documents": []}  # Reset documents and answer for a fresh cycle
3. Defining the Edges (The Logic Gates)
Here's where the magic of conditional_edges comes in. This allows our graph to dynamically route its execution based on the state. It's like the brain deciding whether to re-read a passage, ask a follow-up question, or move on, based on its comprehension.
# --- Define the graph ---
workflow = StateGraph(GraphState)
# Add nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)
workflow.add_node("reflect", reflect)
workflow.add_node("rewrite_query", rewrite_query)
# Set entry point
workflow.set_entry_point("retrieve")
# Define edges
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", "reflect")
# Define the conditional logic for the 'reflect' node
# This function determines the next node based on reflection result and attempts.
MAX_ATTEMPTS = 3 # Hyperparameter: How many times can we retry?
def route_reflection(state: GraphState) -> str:
print(f"---ROUTING based on Reflection: {state['reflection']}, Attempts: {state['attempts']}---")
if state["reflection"] == "satisfactory":
return "end"
if state["attempts"] >= MAX_ATTEMPTS:
print(f"Max attempts ({MAX_ATTEMPTS}) reached. Ending with failure.")
return "fail" # Special state to indicate failure or exit point
if state["reflection"] == "needs_rewrite":
return "rewrite_query"
elif state["reflection"] == "needs_retrieval":
return "retrieve" # Go back to retrieval with same or potentially deeper query
else: # e.g., "fail" or unexpected reflection
return "fail"
# Add conditional edges from the reflect node
workflow.add_conditional_edges(
"reflect",
route_reflection,
{
"retrieve": "retrieve", # If needs retrieval, go back to retrieve
"rewrite_query": "rewrite_query", # If needs rewrite, go to rewrite_query
"end": END, # If satisfactory, end the graph
"fail": END # If failed after attempts, end the graph
}
)
# After rewriting the query, we always go back to retrieval
workflow.add_edge("rewrite_query", "retrieve")
# Compile the graph
app = workflow.compile()
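# Note: besides invoke(), compiled LangGraph apps also expose a stream() method that yields
# per-node state updates, which is handy for watching the retrieve -> generate -> reflect loop
# unfold step by step (e.g. `for step in app.stream(initial_state): print(step)`).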
# --- Run the graph ---
print("\n--- Running the Self-Correcting RAG ---")
query_1 = "Tell me about Agno and its purpose."
initial_state_1 = GraphState(query=query_1, documents=[], answer=None, reflection=None, attempts=0, error=None)
final_state_1 = app.invoke(initial_state_1)
print("\n--- Final State 1 ---")
print(f"Query: {final_state_1['query']}")
print(f"Final Answer: {final_state_1['answer']}")
print(f"Reflection: {final_state_1['reflection']}")
print(f"Attempts: {final_state_1['attempts']}")
print("\n--- Running a challenging query ---")
query_2 = "What are the common challenges in MoE architectures that require competitive programming solutions?"
initial_state_2 = GraphState(query=query_2, documents=[], answer=None, reflection=None, attempts=0, error=None)
final_state_2 = app.invoke(initial_state_2)
print("\n--- Final State 2 ---")
print(f"Query: {final_state_2['query']}")
print(f"Final Answer: {final_state_2['answer']}")
print(f"Reflection: {final_state_2['reflection']}")
print(f"Attempts: {final_state_2['attempts']}")What I Learned: Towards True AI Cognition
Building this self-correcting RAG agent, even with a minimal graph library, reinforced several crucial lessons for anyone pushing the boundaries of AI:
- Iterative Systems are King: Just like human thought, AI systems benefit immensely from iterative refinement. A single pass is rarely enough for complex tasks. Embracing loops and feedback mechanisms is a fundamental step towards more robust and intelligent agents. This is a direct parallel to how our brains constantly update their understanding based on new sensory input and internal reflection.
- Explicit State Management: The `GraphState` isn't just a convenience; it's the memory of our agent. Clearly defining and managing this state is critical. Without it, conditional logic would be impossible, and the agent couldn't learn or adapt within a single interaction.
- The Power of `conditional_edges`: This API, or the concept it represents (state-dependent routing), is a game-changer. It transforms a linear pipeline into a dynamic decision-making network. This is the simplest way to introduce rudimentary "cognition" – the ability to observe, evaluate, and choose a path. For me, this is a stepping stone towards modeling the complex neural pathways and decision networks in a biological brain.
- Performance is Non-Negotiable: Each node in an iterative system adds latency. My competitive programming background screams for efficiency. Using lightweight LLM clients (like Agno, or raw `httpx` calls), optimizing retrieval queries, and minimizing prompt token usage are paramount. An iterative loop that takes too long is useless.
- LLM as a "Cognitive Module": The `reflect` and `rewrite_query` nodes highlight how an LLM can act as a specialized cognitive module – a critic, a re-phraser. This modularity is a key concept in MoE architectures and how different brain regions handle distinct tasks. The quality of these LLM prompts directly dictates the agent's ability to self-correct effectively.
This is just the beginning. My vision extends to more sophisticated reflection, integrating memory systems that evolve beyond a single interaction, and truly distributed MoE architectures where expert agents dynamically contribute and self-organize. For now, the ability to build agents that can think, rethink, and refine their answers autonomously is a powerful step towards building AI that doesn't just process, but truly understands and adapts. And I'll continue doing it with the leanest, most performant code possible.