diff --git a/.gitignore b/.gitignore
index 15201ac..6456e86 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,3 +169,10 @@ cython_debug/
 
 # PyPI configuration file
 .pypirc
+
+
+# custom
+data/
+logs/
+*.csv
+*.txt
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..f8e1440
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,16 @@
+FROM python:3.12-slim
+
+WORKDIR /app
+
+# Install system dependencies (AppArmor)
+RUN apt-get update && apt-get install -y apparmor
+
+COPY requirements.txt requirements.txt
+
+RUN pip install -r requirements.txt
+
+# Copy the entire project
+COPY . .
+
+# Run the main application
+CMD ["python", "main.py"]
\ No newline at end of file
diff --git a/Dockerfile.sandbox b/Dockerfile.sandbox
new file mode 100644
index 0000000..e818250
--- /dev/null
+++ b/Dockerfile.sandbox
@@ -0,0 +1,17 @@
+FROM python:3.11-slim
+
+# Create non-root user
+RUN adduser --disabled-password --gecos '' sandboxuser && \
+    chmod 755 /home/sandboxuser
+
+# Set safe workspace
+WORKDIR /sandbox
+RUN chown sandboxuser:sandboxuser /sandbox
+
+# Minimal build dependencies for the sandbox image
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends gcc python3-dev && \
+    rm -rf /var/lib/apt/lists/*
+
+USER sandboxuser
+ENTRYPOINT ["python", "/sandbox/script.py"]
\ No newline at end of file
diff --git a/README.md b/README.md
index b4bd54e..2c84746 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,45 @@
 # ML-pipeline
 
-A simple pipeline to integrate searching, model iteration, and code correction with local ollama models
-This is part of a work project I'm doing! Currently trying to get my bosses to OK open-sourcing everything
\ No newline at end of file
+A simple pipeline to integrate searching, model iteration, and code correction with local Ollama models.
+
+## Overview
+
+This project comprises several modules that work together to streamline the process of:
+- Creating isolated Docker sandboxes for safe code execution ([`UserEnvironment`](codeExecution.py)).
+- Executing and orchestrating code via [`main.py`](main.py).
+- Handling web search queries, model iterations, and task classification with local Ollama models using functions from [`queries.py`](queries.py) and [`search.py`](search.py).
+- Managing conversation history in a local SQLite database via [`conversation_store.py`](conversation_store.py).
+
+## Project Structure
+
+- **codeExecution.py**: Implements the [`UserEnvironment`](codeExecution.py) class that runs generated code inside a locked-down Docker container with basic security measures.
+- **main.py**: Serves as the entry point to the pipeline, orchestrating code execution and integrating search and model iterations.
+- **queries.py**: Contains functions to perform web search, task classification, and other queries.
+- **search.py**: Provides utilities for performing web searches via the Serper API.
+- **conversation_store.py**: Manages conversation persistence in a SQLite database under the `data/` folder.
+- **debug.py**: Includes debug utilities for troubleshooting.
+
+## Installation
+
+1. Install the necessary dependencies via [requirements.txt](requirements.txt):
+   ```sh
+   pip install -r requirements.txt
+   ```
+
+2. Ensure Docker is installed and running; generated code is executed in a sandboxed container (see [codeExecution.py](codeExecution.py)).
+
+## Usage
+
+Run the pipeline by executing the main script:
+```sh
+python main.py
+```
+
+During execution, the project will:
+- Pull and stream model downloads from Ollama.
+- Orchestrate web searches, model queries, and classification tasks.
+- Maintain a conversation history for iterative improvements.
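+
+For example, to seed the pipeline with a local data file (illustrative file name) and to build the sandbox image that code execution expects:
+```sh
+python main.py --file data/report.txt
+docker build -f Dockerfile.sandbox -t code-sandbox .
+```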
+
+## License
+See [LICENSE](LICENSE) for details.
+
diff --git a/api/api.py b/api/api.py
new file mode 100644
index 0000000..8d9d1ea
--- /dev/null
+++ b/api/api.py
@@ -0,0 +1,128 @@
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+
+from pydantic import BaseModel
+from main import orchestrate, create_vector_store, load_and_chunk_file
+import re
+
+from config import Config
+from codeExecution import UserEnvironment, orchestrate_code
+from queries import (
+    perform_web_search,
+    rag_query,
+    classify_task,
+    MODEL_NAMES
+)
+
+import debug as debugMod
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import HuggingFaceEmbeddings
+import conversation_store
+conversation_store.initialize_db()
+
+app = FastAPI()
+
+
+class UserRequest(BaseModel):
+    query: str
+    user_id: str
+    file_content: str = None
+
+
+class CodeExecutionResponse(BaseModel):
+    result: str
+    error: str = None
+    links: list[str] = []
+
+
+# Maintain user environments and vector stores
+user_environments = {}
+user_vector_stores = {}
+
+
+@app.post("/process-query")
+async def process_query(request: UserRequest):
+    try:
+        # Initialize or retrieve user environment
+        if request.user_id not in user_environments:
+            user_environments[request.user_id] = UserEnvironment(
+                request.user_id)
+            user_vector_stores[request.user_id] = None
+
+        # Handle file upload/content
+        chunks = []
+        if request.file_content:
+            chunks = load_and_chunk_from_content(request.file_content)
+            user_vector_stores[request.user_id] = create_vector_store(chunks)
+
+        # Process query
+        vector_store = user_vector_stores[request.user_id]
+        response, links = orchestrate(
+            request.query,
+            vector_store=vector_store,
+            comm_outp=lambda x: None,  # Disable direct printing
+            comm_inp=lambda x: None  # Disable direct input
+        )
+
+        # Handle code execution
+        code_blocks = re.findall(
+            Config.code_block_regex(), response, re.DOTALL)
+        execution_results = []
+        if code_blocks:
+            execution_results = execute_code_blocks(
+                request.user_id,
+                code_blocks,
+                request.query,
+                response,
+                links,
+                chunks
+            )
+
+        return {
+            "response": response,
+            "links": links,
+            "execution_results": execution_results
+        }
+    except Exception as e:
+        debugMod.log(f"API error: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+def execute_code_blocks(user_id, code_blocks, query, response, links, chunks):
+    results = []
+    user_env = user_environments[user_id]
+    vector_store = user_vector_stores[user_id]
+
+    for code in code_blocks:
+        result = orchestrate_code(
+            orchestrate,
+            vector_store,
+            chunks,
+            user_env,
+            [code],
+            query,
+            response,
+            links
+        )
+        results.append(result)
+    return results
+
+
+def load_and_chunk_from_content(content: str):
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=Config.CHUNK_SIZE,
+        chunk_overlap=Config.CHUNK_OVERLAP
+    )
+    return splitter.split_text(content)
+
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
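A quick way to exercise the endpoint above, assuming the app is served with `uvicorn api.api:app` on port 8000 and that the `requests` package is installed (neither is part of this diff):

```python
# Hypothetical client for the /process-query endpoint defined above.
import requests  # assumed to be installed; not part of this diff

payload = {
    "query": "Summarize the attached notes",
    "user_id": "demo-user",  # any stable ID; used to key user_environments
    "file_content": "Quarterly revenue grew 12%...",  # optional; chunked into a vector store
}

resp = requests.post("http://localhost:8000/process-query", json=payload, timeout=300)
resp.raise_for_status()
body = resp.json()
print(body["response"])            # model answer
print(body["links"])               # web sources used, if any
print(body["execution_results"])   # sandboxed code results, if any
```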
diff --git a/codeExecution.py b/codeExecution.py
new file mode 100644
index 0000000..d64d65e
--- /dev/null
+++ b/codeExecution.py
@@ -0,0 +1,155 @@
+import tempfile
+from pathlib import Path
+import re
+from types import FunctionType
+import docker
+
+import debug as debugMod
+import conversation_store
+from config import Config
+
+
+class UserEnvironment:
+    def __init__(self, user_id: str):
+        self.user_id = user_id
+        self.client = docker.from_env()
+        self.temp_dir = tempfile.TemporaryDirectory(prefix=f"{user_id}_code_")
+
+    def execute_code(self, code: str, context=None, timeout=15, memory_limit=100):
+        # Validate input
+        if len(code) > Config.MAX_CODE_LENGTH:
+            return {"err": "Code exceeds length limit"}
+
+        # Create temp file
+        tmp_path = Path(self.temp_dir.name) / "script.py"
+        with open(tmp_path, "w") as f:
+            if context:
+                f.write(f"context = {repr(context)}\n")
+            f.write(code)
+
+        container = None
+
+        try:
+            # Execute in a locked-down container; the timeout is enforced via
+            # container.wait() below (containers.run has no timeout keyword)
+            container = self.client.containers.run(
+                image="code-sandbox",
+                command=[],
+                volumes={
+                    str(tmp_path): {'bind': '/sandbox/script.py', 'mode': 'ro'}
+                },
+                mem_limit=f"{memory_limit}m",
+                cpu_period=100000,
+                cpu_quota=50000,  # Limit to 50% CPU
+                network_mode='none',
+                user='sandboxuser',
+                read_only=True,
+                security_opt=['no-new-privileges'],
+                cap_drop=['ALL'],
+                detach=True,
+                stdout=True,
+                stderr=True
+            )
+
+            # Wait for completion
+            result = container.wait(timeout=timeout + 5)
+            logs = container.logs().decode()
+
+            # Parse results
+            output = logs[:Config.OUTPUT_CHAR_LIMIT]
+            error = None
+
+            if result['StatusCode'] != 0:
+                error = f"Container exited with code {result['StatusCode']}"
+
+            return {
+                "output": output,
+                "error": error,
+                "status": result['StatusCode']
+            }
+
+        except docker.errors.ContainerError as e:
+            return {"err": f"Container error: {str(e)}"}
+        except docker.errors.DockerException as e:
+            return {"err": f"Docker error: {str(e)}"}
+        except Exception as e:
+            return {"err": f"Execution failed: {str(e)}"}
+        finally:
+            tmp_path.unlink(missing_ok=True)
+            if container:
+                try:
+                    container.remove(force=True)
+                except docker.errors.NotFound:
+                    pass
+
+    def cleanup(self):
+        self.temp_dir.cleanup()
+
+
+def orchestrate_code(orchestrate: FunctionType, vector_store, chunks, user_env: UserEnvironment, code_blocks, query, response, links):
+    debugMod.log("\nExecuting code...\n")
+    results = []
+
+    for code in code_blocks:
+        retry_count = 0
+        current_code = code.strip()
+        last_error = None
+
+        while retry_count < Config.MAX_CODE_RETRIES:
+            execution_result = user_env.execute_code(
+                current_code, context=chunks if chunks else None)
+
+            if isinstance(execution_result, dict) and 'err' in execution_result:
+                # Hard-coded message to let the user know the program didn't explode
+                debugMod.log(
+                    "\n\nhmmm...looks like this code didn't work properly, I'll try debugging it now!\n")
+
+                last_error = execution_result['err']
+                debugMod.log(f"\nExecution error: {last_error}\n")
+
+                # Generate fix prompt using full orchestration
+                fix_prompt = f"""Fix this Python code. Error: {last_error}
+                Code:
+                ```python
+                {current_code}
+                ```
+                Requirements:
+                1. Preserve original functionality
+                2. Explain fixes in comments
+                3. Return ONLY corrected code in a single Python block"""
+
+                [fixed_response, _] = orchestrate(fix_prompt, vector_store)
+                new_blocks = re.findall(
+                    Config.code_block_regex(), fixed_response, re.DOTALL)
+
+                if new_blocks:
+                    current_code = new_blocks[0].strip()
+                    retry_count += 1
+                    debugMod.log(f"\nRetry #{retry_count} with modified code\n")
+                else:
+                    break
+            else:
+                debugMod.log(f"\nCode Execution Result:\n{execution_result}\n")
+                if execution_result:
+                    # Get current conversation ID after saving conversation
+                    conv_id = conversation_store.save_conversation(
+                        query, response, links)
+
+                    # Save code execution with context
+                    conversation_store.save_code_execution(
+                        code=current_code,
+                        result=execution_result,
+                        error=execution_result.get('err') if isinstance(
+                            execution_result, dict) else None,
+                        retries=retry_count,
+                        conversation_id=conv_id
+                    )
+                results.append(execution_result)
+                break
+
+        if last_error and retry_count >= Config.MAX_CODE_RETRIES:
+            debugMod.log(
+                f"\nFailed to fix after {Config.MAX_CODE_RETRIES} attempts. Final error: {last_error}\n")
+            # Request human intervention via orchestration
+            help_response = orchestrate(
+                f"Explain this code error to user: {last_error}",
+                vector_store
+            )[0]
+            debugMod.log(help_response + "\n")
+
+    # Return results so callers (e.g. the API) can report them
+    return results
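A minimal sketch of driving the sandbox directly, outside the orchestration retry loop; it assumes Docker is running and that the `code-sandbox` image from `Dockerfile.sandbox` has been built:

```python
# Hypothetical direct use of UserEnvironment; not part of this diff.
from codeExecution import UserEnvironment

env = UserEnvironment("demo-user")
result = env.execute_code("print(sum(range(10)))", timeout=10, memory_limit=64)

if "err" in result:
    print("sandbox failure:", result["err"])    # validation/Docker errors
else:
    print("exit status:", result["status"])     # container exit code
    print("output:", result["output"])          # logs, truncated to OUTPUT_CHAR_LIMIT
env.cleanup()                                   # removes the temporary script directory
```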
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..e986521
--- /dev/null
+++ b/config.py
@@ -0,0 +1,61 @@
+import os
+
+
+class Config:
+    # === System Architecture ===
+    USER_ID = "ION606"
+    LOG_DIR = "logs"
+
+    # === Code Execution ===
+    CODE_TIMEOUT = 15  # Seconds before killing process
+    CODE_MEMORY_LIMIT_MB = 100  # Memory limit for subprocesses
+    OUTPUT_CHAR_LIMIT = 2000  # Max characters for stdout/stderr
+    MAX_CODE_RETRIES = 3  # Max code fix attempts
+    CORE_DUMP_LIMIT = 0  # Disable core dumps
+    MAX_CODE_LENGTH = 10000  # Character limit for code inputs
+
+    # === Orchestration ===
+    MAX_ORCHESTRATION_ITERATIONS = 5  # Max loops in orchestrate()
+    SEARCH_CONTEXT_LIMIT = 3000  # Characters for web context
+    LOCAL_CONTEXT_ITEMS = 3  # File chunks to include
+
+    # === Text Processing ===
+    CHUNK_SIZE = 1000  # Document chunking size
+    CHUNK_OVERLAP = 200  # Document chunk overlap
+    RESPONSE_TOKEN_LIMIT = 4096  # Max LLM response tokens
+
+    # === Search ===
+    MAX_SEARCH_RESULTS = 3  # Web results per query
+    MIN_SEARCH_QUERY_LEN = 3  # Minimum search terms
+
+    # === Conversation Store ===
+    CONVERSATION_DB_PATH = os.path.join(
+        os.path.dirname(__file__), "data/conversations.db")
+    MAX_QUERY_LENGTH = 2000  # Characters for stored queries
+    MAX_RESPONSE_LENGTH = 10000  # Characters for stored responses
+
+    # === Model Settings ===
+    MODEL_TEMPERATURE = 0.7  # Default creativity level
+    MAX_CLASSIFY_ATTEMPTS = 3  # Task classification retries
+
+    # === Safety Limits ===
+    MAX_INPUT_LENGTH = 15000  # Absolute input character limit
+    MAX_CONTEXT_DEPTH = 5  # Max nested context references
+
+    # === Formatting ===
+    @staticmethod
+    def valid_range(num_options: int) -> str:
+        return f"0-{num_options - 1}"
+
+    @staticmethod
+    def code_block_regex():
+        return r'```python(.*?)```'
+
+    # === Path Helpers ===
+    @staticmethod
+    def chroma_path():
+        return "./data/chroma_db"
+
+    @staticmethod
+    def debug_log_path():
+        return f"{Config.LOG_DIR}/debug.txt"
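`Config.code_block_regex()` above is what both `main.py` and the API use to pull runnable code out of a model reply; a small, self-contained illustration (the reply text is made up):

```python
# Illustrative extraction of ```python ... ``` blocks from a model reply.
import re
from config import Config

reply = "Here you go:\n```python\nprint('hello')\n```\nAnything else?"
blocks = re.findall(Config.code_block_regex(), reply, re.DOTALL)
print(blocks)  # ["\nprint('hello')\n"]: the capture keeps the surrounding newlines
```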
diff --git a/conversation_store.py b/conversation_store.py
new file mode 100644
index 0000000..0c4655c
--- /dev/null
+++ b/conversation_store.py
@@ -0,0 +1,84 @@
+from datetime import datetime
+import sqlite3
+import os
+
+import debug as debugModule
+from config import Config
+
+
+def initialize_db():
+    # Make sure the data/ directory exists before connecting
+    os.makedirs(os.path.dirname(Config.CONVERSATION_DB_PATH), exist_ok=True)
+
+    conn = sqlite3.connect(Config.CONVERSATION_DB_PATH)
+    c = conn.cursor()
+
+    # Create conversations table
+    c.execute('''CREATE TABLE IF NOT EXISTS conversations
+                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
+                  query TEXT,
+                  response TEXT,
+                  links TEXT,
+                  timestamp DATETIME)''')
+
+    # Create code_executions table
+    c.execute('''CREATE TABLE IF NOT EXISTS code_executions
+                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
+                  conversation_id INTEGER,
+                  code_content TEXT,
+                  execution_result TEXT,
+                  error_message TEXT,
+                  retry_count INTEGER,
+                  timestamp DATETIME,
+                  FOREIGN KEY(conversation_id) REFERENCES conversations(id))''')
+
+    conn.commit()
+    conn.close()
+    debugModule.log("database initialized successfully!")
+
+
+def save_conversation(query, response, links):
+    if len(query) > Config.MAX_QUERY_LENGTH or len(response) > Config.MAX_RESPONSE_LENGTH:
+        raise ValueError("Input too large")
+
+    conn = sqlite3.connect(Config.CONVERSATION_DB_PATH)
+    c = conn.cursor()
+    c.execute('''INSERT INTO conversations
+                 (query, response, links, timestamp)
+                 VALUES (?, ?, ?, ?)''',
+              (query, response, ", ".join(links) if links else "", datetime.now()))
+    conversation_id = c.lastrowid  # Return the generated conversation ID (simpler than UUIDs)
+    conn.commit()
+    conn.close()
+    return conversation_id
+
+
+def save_code_execution(code, result, error=None, retries=0, conversation_id=None):
+    conn = sqlite3.connect(Config.CONVERSATION_DB_PATH)
+    c = conn.cursor()
+
+    # Convert execution result to string if needed
+    if isinstance(result, dict):
+        execution_result = str(result.get('err', ''))
+        error_message = result.get('err', '')
+    else:
+        execution_result = str(result)
+        error_message = error or ''
+
+    c.execute('''INSERT INTO code_executions
+                 (conversation_id, code_content, execution_result,
+                  error_message, retry_count, timestamp)
+                 VALUES (?, ?, ?, ?, ?, ?)''',
+              (conversation_id, code, execution_result,
+               error_message, retries, datetime.now()))
+
+    conn.commit()
+    conn.close()
diff --git a/debug.py b/debug.py
new file mode 100644
index 0000000..2ea2be7
--- /dev/null
+++ b/debug.py
@@ -0,0 +1,31 @@
+import datetime
+import os
+from config import Config
+
+
+def log(txt, wrapped=False):
+    try:
+        if not os.path.exists(Config.LOG_DIR):
+            os.makedirs(Config.LOG_DIR)
+
+        with open(Config.debug_log_path(), 'a') as f:
+            if wrapped:
+                f.write('==============================================\n')
+
+            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+            lines = txt.split('\n')
+            for i, line in enumerate(lines):
+                prefix = f"{timestamp} - " if i == 0 else " " * \
+                    len(timestamp) + "   "
+                f.writelines([f"{prefix}{line}\n"])
+
+            if wrapped:
+                f.write('==============================================\n')
+    except Exception as e:
+        print(f"Failed to write to debug.txt: {e}")
+
+
+def moveDebugLog():
+    deblogpath = os.path.join(Config.LOG_DIR, 'debug.txt')
+    if os.path.exists(Config.debug_log_path()):
+        os.rename(deblogpath, os.path.join(Config.LOG_DIR, 'old-debug.txt'))
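Because the store above is plain SQLite, past runs can be inspected with the standard library alone; a read-only sketch (the database path comes from `Config.CONVERSATION_DB_PATH`):

```python
# Inspect recent conversations and code executions saved by conversation_store.
import sqlite3
from config import Config

conn = sqlite3.connect(Config.CONVERSATION_DB_PATH)
c = conn.cursor()

for row in c.execute(
        "SELECT id, query, timestamp FROM conversations ORDER BY id DESC LIMIT 5"):
    print(row)

for row in c.execute(
        "SELECT conversation_id, retry_count, error_message FROM code_executions ORDER BY id DESC LIMIT 5"):
    print(row)

conn.close()
```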
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..2bbc0c1
--- /dev/null
+++ b/main.py
@@ -0,0 +1,224 @@
+from codeExecution import UserEnvironment, orchestrate_code
+from queries import (
+    perform_web_search,
+    rag_query,
+    classify_task,
+    MODEL_NAMES
+)
+import debug as debugMod
+from search import perform_web_search
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import HuggingFaceEmbeddings
+import os
+import argparse
+import re
+import ollama
+from config import Config
+import conversation_store
+conversation_store.initialize_db()
+
+ollama.Client(host='http://ollama:11434')
+
+# Just in case
+for complexity, model_name in MODEL_NAMES.items():
+    print(f"Pulling {complexity} model ({model_name})...")
+
+    # Stream the pull process and print progress
+    for progress in ollama.pull(model=model_name, stream=True):
+        if "status" in progress:
+            print(f" {progress['status']}", end="\r")  # Overwrite the same line
+        if "completed" in progress and "total" in progress:
+            # Calculate and print download percentage
+            percent = (progress["completed"] / progress["total"]) * 100
+            print(f" Downloading: {percent:.1f}% complete", end="\r")
+
+    print("\nDone!")  # Newline after each model is pulled
+
+
+def load_and_chunk_file(file_path):
+    debugMod.log(f"Loading and chunking file: {file_path}")
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File {file_path} not found")
+
+    with open(file_path, "r") as f:
+        text = f.read()
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=Config.CHUNK_SIZE, chunk_overlap=Config.CHUNK_OVERLAP)
+    chunks = splitter.split_text(text)
+    debugMod.log(f"File chunked into {len(chunks)} chunks")
+    return chunks
+
+
+def create_vector_store(chunks):
+    debugMod.log("Creating vector store")
+    if not chunks:
+        debugMod.log("No chunks provided, returning None")
+        return None
+    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+    vector_store = Chroma.from_texts(
+        chunks,
+        embeddings,
+        persist_directory=Config.chroma_path()
+    )
+
+    debugMod.log("Vector store created")
+    return vector_store
+
+
+def orchestrate(query, vector_store=None, comm_outp=print, comm_inp=input):
+    debugMod.log(f"Orchestrating query: {query}")
+    aggregated_web_context = ""
+    local_context = ""
+    user_context = ""
+    response_context = ""
+    links = []
+
+    # Classify task once at start
+    task_type = classify_task(query)
+    debugMod.log(f"Task classified as: {task_type}")
+
+    # Early exit for simple tasks
+    if task_type == "simple":
+        debugMod.log("Direct response for simple task")
+        return [rag_query(query, task_type=task_type), []]
+
+    # Initialize context for medium/complex tasks
+    if vector_store:
+        docs = vector_store.similarity_search(query, k=3)
+        local_context = "\n".join(
+            [d.page_content for d in docs]) if docs else ""
+        debugMod.log(f"Local context: {local_context}")
+
+    iteration = 0
+    status = "continue"
+
+    while iteration < Config.MAX_ORCHESTRATION_ITERATIONS and status != "final":
+        debugMod.log(f"--- Iteration {iteration} [Status: {status}] ---")
+        response = ""
+
+        if status == "continue":
+            # Include previous responses in reflection
+            reflection_prompt = f"""Determine the next action needed to answer: {query}
+
+            Available actions:
+            1. web_search - Needs web information
+            2. user_input - Requires clarification
+            3. final_response - Ready to answer
+
+            Context:
+            - Web: {aggregated_web_context}
+            - Local: {local_context}
+            - User: {user_context}
+            - Previous Responses: {response_context}
+
+            Return ONLY: web_search/user_input/final_response"""
+
+            status = rag_query(
+                reflection_prompt, task_type=task_type, silent=True).strip().lower()
+            debugMod.log(f"Action determined: {status}")
+
+        if status == "web_search":
+            search_prompt = f"""Generate search query considering: {query}
+            Previous responses: {response_context}
+            Return ONLY search terms"""
+
+            search_terms = rag_query(
+                search_prompt, task_type=task_type, silent=True).strip('"')
+            debugMod.log(f"Searching web for: {search_terms}")
+
+            web_results, new_links = perform_web_search(search_terms)
+            links.extend(new_links)
+
+            if web_results:
+                aggregated_web_context += f"\nWeb: {web_results}"
+                debugMod.log("Updated web context")
+
+        elif status == "user_input":
+            comm_outp("\n[System] Additional info needed:")
+            user_input = comm_inp("Please clarify: ")
+            user_context += f"\nUser input: {user_input}"
+            debugMod.log("Received user input")
+            status = "continue"
+
+        elif status == "final_response":
+            break
+
+        else:
+            debugMod.log(f"Unknown status: {status}")
+            status = "final_response"
+
+        # Generate and store response
+        if status != "final_response":
+            response = rag_query(
+                query,
+                task_type=task_type,
+                web_context=aggregated_web_context,
+                local_context=local_context,
+                user_context=user_context,
+                response_context=response_context  # Pass previous responses
+            )
+            response_context += f"\nIteration {iteration} response: {response}"
+            debugMod.log(f"Iteration {iteration} response stored")
+
+        iteration += 1
+
+    # Generate final response with full context
+    final_response = rag_query(
+        f"Final answer considering: {query}",
+        task_type=task_type,
+        web_context=aggregated_web_context,
+        local_context=local_context,
+        user_context=user_context,
+        response_context=response_context
+    )
+
+    debugMod.log("Orchestration completed")
+    return [final_response, links]
+
+
+if __name__ == "__main__":
+    debugMod.moveDebugLog()
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--file', type=str, default="",
+                        help='Path to data file for analysis')
+    parser.add_argument('--cli', type=str, default="false",
+                        help="whether to use the CLI for input or run the API")
+    args = parser.parse_args()
+
+    vector_store = None
+    chunks = []
+
+    if args.file:
+        try:
+            debugMod.log(f"Loading file: {args.file}")
+            chunks = load_and_chunk_file(args.file)
+            vector_store = create_vector_store(chunks)
+        except Exception as e:
+            debugMod.log(f"Error loading file: {str(e)}")
+            chunks = []
+
+    user_env = UserEnvironment("ION606")
+
+    # if args.cli:
+    while True:
+        query = input("\nEnter your query (type 'exit' to quit): ")
+
+        if query.lower() == 'exit':
+            break
+
+        [response, links] = orchestrate(query, vector_store)
+
+        if len(links) > 0:
+            print(f"links: {', '.join(links)}")
+
+        # Save conversation to SQLite
+        conversation_store.save_conversation(query, response, links)
+
+        # Extract and run any generated code blocks
+        code_blocks = re.findall(Config.code_block_regex(), response, re.DOTALL)
+        if code_blocks:
+            orchestrate_code(orchestrate, vector_store, chunks,
+                             user_env, code_blocks, query, response, links)
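`orchestrate` can also be called programmatically, not just from the CLI loop; a minimal sketch with no file context (note that importing `main` triggers the module-level Ollama model pulls first):

```python
# Hypothetical programmatic call into the orchestration entry point.
from main import orchestrate

response, links = orchestrate("What changed in Python 3.12 f-strings?")
print(response)
if links:
    print("sources:", ", ".join(links))
```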
diff --git a/queries.py b/queries.py
new file mode 100644
index 0000000..52294fd
--- /dev/null
+++ b/queries.py
@@ -0,0 +1,186 @@
+import debug as debugMod
+from search import perform_web_search
+import ollama
+import conversation_store
+conversation_store.initialize_db()
+
+# model notes - better: qwen2.5-coder:14b; faster (but worse): phi3; with more processing power: deepseek-r1:32b
+MODEL_NAMES = {
+    "classification": "dolphin3:8b",  # Best for structured tasks
+    "simple": "phi3:latest",  # phi3:mini
+    "medium": "llama3:8b-instruct-q8_0",
+    "complex": "deepseek-coder:33b-instruct-q4_K_M"
+}
+
+
+def classify_task(query: str) -> str:
+    # Use a tiny model to classify the task
+    prompt = f"""Classify this query into one of these categories:
+    - "simple": greetings, yes/no, basic facts
+    - "medium": summarization, simple coding
+    - "complex": advanced coding, data analysis, multi-step reasoning
+
+    Query: {query}
+    Return ONLY the category name (e.g., "simple")."""
+
+    toPassIn = ""
+    for i in range(3):
+        response = ollama.chat(model=MODEL_NAMES["classification"], messages=[
+            {"role": "user", "content": prompt + toPassIn}])
+        task_type = response["message"]["content"].strip().lower()
+        if task_type in ("simple", "medium", "complex"):
+            return task_type
+        else:
+            toPassIn += f"\nthe last response '{task_type}' was incorrect (AKA not one of simple/medium/complex), try again and pick one of these based on the above"
+
+    return 'complex'
+
+
+def generate_prompt(query, web_context, local_context, user_context, response_context="", onlyRules=False):
+    prompt = f"""
+    **Strict Response Rules**
+    1. Greetings & Casual Queries:
+       - For greetings (e.g. "good morning", "hello"):
+         * Respond with ONLY a short friendly acknowledgment
+         * NEVER explain why you can't chat casually
+         * Example: "Good morning! How can I assist you today?"
+
+    2. Technical Responses:
+       - Generate code ONLY if:
+         * User explicitly requests technical help
+         * Local file context exists for data analysis tasks
+       - Keep code explanations concise (1-2 sentences max)
+
+    3. Web Search Policy:
+       - NEVER search for greetings/casual conversation
+       - Search only when:
+         * Technical info is needed
+         * Local data is insufficient
+
+    4. Formatting:
+       - NO markdown/bullets in casual responses
+       - NO internal system references (e.g. "Technilopia Forum")
+       - NO justification of rules to users
+       - NEVER include the user's question unless explicitly asked to do so
+       - NEVER include previous responses
+
+    {f'Local File Context: {local_context}' if local_context else ''}
+    """
+
+    if onlyRules:
+        return prompt
+
+    prompt = f"""
+    Context Sources:\n
+    {f'[WEB] {web_context}' if web_context else ''}\n
+    {f'[LOCAL FILE] {local_context}' if local_context else ''}\n
+    {f'[USER CONTEXT] {user_context}' if user_context else ''}\n
+    \n[PREVIOUS RESPONSES] {response_context}\n
+    Question: {query}
+
+    {prompt}
+    """
+    return prompt
+
+
+def call_ollama_and_print(task_type, prompt, silent=False):
+    if silent:
+        response = ollama.chat(model=MODEL_NAMES[task_type], messages=[
+            {"role": "user", "content": prompt}])
+        debugMod.log("RAG query response received")
+        return response
+
+    full_response = ""
+    print("\nAI Response: ", end="", flush=True)  # Start response line
+
+    # Stream the response
+    stream = ollama.chat(
+        model=MODEL_NAMES[task_type],
+        messages=[{"role": "user", "content": prompt}],
+        stream=True
+    )
+
+    for chunk in stream:
+        content = chunk.get('message', {}).get('content', '')
+        print(content, end="", flush=True)  # Stream to terminal
+        full_response += content
+
+    print()  # Newline after streaming
+    debugMod.log("RAG query response received")
+    return full_response
+
+
+def multi_choice_query(query, options: list[str], task_type: str, web_context="", local_context="", user_context="", silent=False):
+    attempts = 0
+    max_attempts = 3
+    inds = list(range(len(options)))
+    valid_range = f"0-{len(inds) - 1}"
+    last_error = ""
+
+    debugMod.log(
+        f"Multi-choice query with options: {', '.join([f'{i}: {opt}' for i, opt in enumerate(options)])}")
+
+    while attempts < max_attempts:
+        prompt = f"""Return ONLY the numeric index ({valid_range}) for the best option. Invalid responses will be rejected.
+
+        Available Options:
+        {"\n".join([f"{i}: {option}" for i, option in enumerate(options)])}
+
+        Question: {query}
+
+        Context Sources:
+        {f'[WEB] {web_context}' if web_context else ''}
+        {f'[LOCAL] {local_context}' if local_context else ''}
+        {f'[USER] {user_context}' if user_context else ''}
+
+        {generate_prompt(query, web_context, local_context, user_context, onlyRules=True)}
+        - You MUST return a SINGLE INTEGER between {valid_range}
+        - DO NOT include explanations or punctuation"""
+
+        if last_error:
+            prompt += f"\n\nPrevious invalid response: {last_error}"
+
+        try:
+            content = call_ollama_and_print(task_type, prompt, silent)
+            if silent:
+                # Silent calls return the raw ollama response dict, not a string
+                content = content["message"]["content"]
+            content = content.strip()
+            debugMod.log(f"Multi-choice response: {content}", wrapped=True)
+
+            # Strict validation
+            if not content.isdigit():
+                raise ValueError(f"Non-numeric response: {content}")
+
+            ind = int(content)
+
+            if 0 <= ind < len(options):
+                debugMod.log(f"Valid choice selected: {ind} ({options[ind]})")
+                return options[ind]
+
+            raise IndexError(f"Index {ind} out of range {valid_range}")
+
+        except (ValueError, IndexError) as e:
+            last_error = str(e)
+            debugMod.log(f"Validation failed: {last_error}")
+            attempts += 1
+            continue
+
+        except Exception as e:
+            debugMod.log(f"Unexpected error: {str(e)}")
+            attempts += 1
+            continue
+
+    # Fallback to safest option after all attempts
+    debugMod.log("All attempts failed. Defaulting to first option")
+    return options[0]
+
+
+def rag_query(query, task_type: str = None, web_context="", local_context="", user_context="", response_context="", silent=False):
+    # Model selection logic
+    task_type = classify_task(query) if not task_type else task_type
+
+    debugMod.log(f"Generating {task_type} RAG query with query: {query}")
+    prompt = generate_prompt(
+        query, web_context, local_context, user_context, response_context)
+
+    response = call_ollama_and_print(task_type, prompt, silent)
+
+    # call_ollama_and_print returns a dict when silent, a streamed string otherwise
+    return response["message"]["content"] if silent else response
diff --git a/search.py b/search.py
new file mode 100644
index 0000000..65b569b
--- /dev/null
+++ b/search.py
@@ -0,0 +1,37 @@
+import os
+import dotenv
+
+from config import Config
+
+dotenv.load_dotenv('./.env')
+
+from langchain.tools import Tool
+from langchain_community.utilities import GoogleSerperAPIWrapper
+
+# Set up the search tool
+search = GoogleSerperAPIWrapper(serper_api_key=os.getenv("SERPER_API_KEY"))
+tool = Tool(
+    name="Web Search",
+    func=search.run,
+    description="Useful for finding real-time information on the web."
+)
+
+
+def perform_web_search(query: str):
+    try:
+        if not query or query.strip() == "" or query.strip().lower() == 'none':
+            return ["", []]
+
+        # Use .results() to get the structured Serper JSON (run() returns plain text)
+        results = search.results(query.strip())
+        if not results:
+            return ["No web results found", []]
+
+        # Extract snippets and links from search results
+        if 'organic' in results:
+            snippets = "\n".join([f"{res['title']}: {res['snippet']}"
+                                  for res in results['organic'][:Config.MAX_SEARCH_RESULTS]])
+            links = [res['link']
+                     for res in results['organic'][:Config.MAX_SEARCH_RESULTS]]
+            return [snippets, links]
+        return [str(results), []]
+    except Exception as e:
+        return [f"Search error: {str(e)}", []]
\ No newline at end of file
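`perform_web_search` always returns a two-item list (snippet text, link list), which keeps the orchestration loop simple; a standalone sketch (requires a valid `SERPER_API_KEY` in `.env`):

```python
# Hypothetical standalone use of the search helper.
from search import perform_web_search

snippets, links = perform_web_search("ollama structured outputs")
print(snippets)  # newline-joined "title: snippet" pairs, or an error/fallback string
print(links)     # up to Config.MAX_SEARCH_RESULTS result URLs (may be empty)
```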