mirror of
https://github.com/ION606/ML-pipeline.git
synced 2026-05-14 21:06:54 +00:00
added rudamentary search and code execution
This commit is contained in:
@@ -0,0 +1,224 @@
|
||||
from codeExecution import UserEnvironment, orchestrate_code
|
||||
from queries import (
|
||||
perform_web_search,
|
||||
rag_query,
|
||||
classify_task,
|
||||
MODEL_NAMES
|
||||
)
|
||||
import debug as debugMod
|
||||
from search import perform_web_search
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
import os
|
||||
import argparse
|
||||
import re
|
||||
import ollama
|
||||
from config import Config
|
||||
import conversation_store
|
||||
conversation_store.initialize_db()
|
||||
|
||||
ollama.Client(host='http://ollama:11434')
|
||||
|
||||
# Just in case
|
||||
for complexity, model_name in MODEL_NAMES.items():
|
||||
print(f"Pulling {complexity} model ({model_name})...")
|
||||
|
||||
# Stream the pull process and print progress
|
||||
for progress in ollama.pull(model=model_name, stream=True):
|
||||
if "status" in progress:
|
||||
print(f" {progress['status']}", end="\r") # Overwrite the same line
|
||||
if "completed" in progress and "total" in progress:
|
||||
# Calculate and print download percentage
|
||||
percent = (progress["completed"] / progress["total"]) * 100
|
||||
print(f" Downloading: {percent:.1f}% complete", end="\r")
|
||||
|
||||
print("\nDone!") # Newline after each model is pulled
|
||||
|
||||
|
||||
def load_and_chunk_file(file_path):
|
||||
debugMod.log(f"Loading and chunking file: {file_path}")
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File {file_path} not found")
|
||||
|
||||
with open(file_path, "r") as f:
|
||||
text = f.read()
|
||||
splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=Config.CHUNK_SIZE, chunk_overlap=Config.CHUNK_OVERLAP)
|
||||
chunks = splitter.split_text(text)
|
||||
debugMod.log(f"File chunked into {len(chunks)} chunks")
|
||||
return chunks
|
||||
|
||||
|
||||
def create_vector_store(chunks):
|
||||
debugMod.log("Creating vector store")
|
||||
if not chunks:
|
||||
debugMod.log("No chunks provided, returning None")
|
||||
return None
|
||||
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
|
||||
|
||||
vector_store = Chroma.from_texts(
|
||||
chunks,
|
||||
embeddings,
|
||||
persist_directory=Config.chroma_path()
|
||||
)
|
||||
|
||||
debugMod.log("Vector store created")
|
||||
return vector_store
|
||||
|
||||
|
||||
def orchestrate(query, vector_store=None, comm_outp=print, comm_inp=input):
|
||||
debugMod.log(f"Orchestrating query: {query}")
|
||||
aggregated_web_context = ""
|
||||
local_context = ""
|
||||
user_context = ""
|
||||
response_context = ""
|
||||
links = []
|
||||
|
||||
# Classify task once at start
|
||||
task_type = classify_task(query)
|
||||
debugMod.log(f"Task classified as: {task_type}")
|
||||
|
||||
# Early exit for simple tasks
|
||||
if task_type == "simple":
|
||||
debugMod.log("Direct response for simple task")
|
||||
return [rag_query(query, task_type=task_type), []]
|
||||
|
||||
# Initialize context for medium/complex tasks
|
||||
if vector_store:
|
||||
docs = vector_store.similarity_search(query, k=3)
|
||||
local_context = "\n".join(
|
||||
[d.page_content for d in docs]) if docs else ""
|
||||
debugMod.log(f"Local context: {local_context}")
|
||||
|
||||
iteration = 0
|
||||
status = "continue"
|
||||
|
||||
while iteration < Config.MAX_ORCHESTRATION_ITERATIONS and status != "final":
|
||||
debugMod.log(f"--- Iteration {iteration} [Status: {status}] ---")
|
||||
response = ""
|
||||
|
||||
if status == "continue":
|
||||
# Include previous responses in reflection
|
||||
reflection_prompt = f"""Determine the next action needed to answer: {query}
|
||||
|
||||
Available actions:
|
||||
1. web_search - Needs web information
|
||||
2. user_input - Requires clarification
|
||||
3. final_response - Ready to answer
|
||||
|
||||
Context:
|
||||
- Web: {aggregated_web_context}
|
||||
- Local: {local_context}
|
||||
- User: {user_context}
|
||||
- Previous Responses: {response_context}
|
||||
|
||||
Return ONLY: web_search/user_input/final_response"""
|
||||
|
||||
status = rag_query(
|
||||
reflection_prompt, task_type=task_type, silent=True).strip().lower()
|
||||
debugMod.log(f"Action determined: {status}")
|
||||
|
||||
if status == "web_search":
|
||||
search_prompt = f"""Generate search query considering: {query}
|
||||
Previous responses: {response_context}
|
||||
Return ONLY search terms"""
|
||||
|
||||
search_terms = rag_query(
|
||||
search_prompt, task_type=task_type, silent=True).strip('"')
|
||||
debugMod.log(f"Searching web for: {search_terms}")
|
||||
|
||||
web_results, new_links = perform_web_search(search_terms)
|
||||
links.extend(new_links)
|
||||
|
||||
if web_results:
|
||||
aggregated_web_context += f"\nWeb: {web_results}"
|
||||
debugMod.log(f"Updated web context")
|
||||
|
||||
elif status == "user_input":
|
||||
comm_outp("\n[System] Additional info needed:")
|
||||
user_input = comm_inp("Please clarify: ")
|
||||
user_context += f"\nUser input: {user_input}"
|
||||
debugMod.log(f"Received user input")
|
||||
status = "continue"
|
||||
|
||||
elif status == "final_response":
|
||||
break
|
||||
|
||||
else:
|
||||
debugMod.log(f"Unknown status: {status}")
|
||||
status = "final_response"
|
||||
|
||||
# Generate and store response
|
||||
if status != "final_response":
|
||||
response = rag_query(
|
||||
query,
|
||||
task_type=task_type,
|
||||
web_context=aggregated_web_context,
|
||||
local_context=local_context,
|
||||
user_context=user_context,
|
||||
response_context=response_context # Pass previous responses
|
||||
)
|
||||
response_context += f"\nIteration {iteration} response: {response}"
|
||||
debugMod.log(f"Iteration {iteration} response stored")
|
||||
|
||||
iteration += 1
|
||||
|
||||
# Generate final response with full context
|
||||
final_response = rag_query(
|
||||
f"Final answer considering: {query}",
|
||||
task_type=task_type,
|
||||
web_context=aggregated_web_context,
|
||||
local_context=local_context,
|
||||
user_context=user_context,
|
||||
response_context=response_context
|
||||
)
|
||||
|
||||
debugMod.log("Orchestration completed")
|
||||
return [final_response, links]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
debugMod.moveDebugLog()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--file', type=str, default="",
|
||||
help='Path to data file for analysis')
|
||||
parser.add_argument('--cli', type=str, default="false",
|
||||
help="whether to use the CLI for input or run the API")
|
||||
args = parser.parse_args()
|
||||
|
||||
vector_store = None
|
||||
chunks = []
|
||||
|
||||
if args.file:
|
||||
try:
|
||||
debugMod.log(f"Loading file: {args.file}")
|
||||
chunks = load_and_chunk_file(args.file)
|
||||
vector_store = create_vector_store(chunks)
|
||||
except Exception as e:
|
||||
debugMod.log(f"Error loading file: {str(e)}")
|
||||
chunks = []
|
||||
|
||||
user_env = UserEnvironment("ION606")
|
||||
|
||||
# if args.cli:
|
||||
while True:
|
||||
query = input("\nEnter your query (type 'exit' to quit): ")
|
||||
|
||||
if query.lower() == 'exit':
|
||||
break
|
||||
|
||||
[response, links] = orchestrate(query, vector_store)
|
||||
|
||||
if len(links) > 0:
|
||||
print(f"links: {", ".join(links)}")
|
||||
|
||||
# Save conversation to SQLite
|
||||
conversation_store.save_conversation(query, response, links)
|
||||
|
||||
# code
|
||||
code_blocks = re.findall(Config.code_block_regex(), response, re.DOTALL)
|
||||
if code_blocks:
|
||||
orchestrate_code(orchestrate, vector_store, chunks,
|
||||
user_env, code_blocks, query, response, links)
|
||||
Reference in New Issue
Block a user