mirror of
https://github.com/ION606/ML-pipeline.git
synced 2026-05-14 21:06:54 +00:00
added rudimentary search and code execution
This commit is contained in:
+128
@@ -0,0 +1,128 @@
|
||||
import re
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from pydantic import BaseModel

import conversation_store
import debug as debugMod
from codeExecution import UserEnvironment, orchestrate_code
from config import Config
from main import orchestrate, create_vector_store, load_and_chunk_file
from queries import (
    perform_web_search,
    rag_query,
    classify_task,
    MODEL_NAMES
)
||||
# Ensure the conversation-history database exists before the app serves
# any requests (runs once at import time).
conversation_store.initialize_db()

app = FastAPI()
|
||||
|
||||
|
||||
class UserRequest(BaseModel):
    """Request payload for the /process-query endpoint."""

    # Natural-language query to run through the orchestration pipeline.
    query: str
    # Stable identifier used to key per-user environments and vector stores.
    user_id: str
    # Optional raw text of an uploaded file; chunked and indexed when present.
    # Fix: `str = None` is an invalid annotation (pydantic v2 rejects a None
    # default on a non-Optional field) — declare it Optional explicitly.
    file_content: Optional[str] = None
|
||||
|
||||
|
||||
class CodeExecutionResponse(BaseModel):
    """Result of running one or more extracted code blocks."""

    # Combined output/result text from the execution.
    result: str
    # Error message when execution failed; None on success.
    # Fix: `str = None` is an invalid annotation (pydantic v2 rejects a None
    # default on a non-Optional field) — declare it Optional explicitly.
    error: Optional[str] = None
    # Source links associated with the answer that produced the code.
    links: list[str] = []
|
||||
|
||||
|
||||
# Maintain user environments and vector stores.
# Both dicts are keyed by user_id and populated lazily on a user's first
# /process-query call.
user_environments = {}
# Value is a vector store built from uploaded file content, or None until
# the user uploads a file.
user_vector_stores = {}
|
||||
|
||||
|
||||
@app.post("/process-query")
async def process_query(request: UserRequest):
    """Run a user query through the pipeline.

    Indexes any uploaded file content into the user's vector store first,
    then orchestrates the query and executes any code blocks found in the
    model's answer. Returns the answer, source links, and execution results.
    Raises HTTPException(500) on any internal failure.
    """
    try:
        uid = request.user_id

        # Lazily create the sandbox and vector-store slot for new users.
        if uid not in user_environments:
            user_environments[uid] = UserEnvironment(uid)
            user_vector_stores[uid] = None

        # Index uploaded file content, replacing any previous store.
        doc_chunks = []
        if request.file_content:
            doc_chunks = load_and_chunk_from_content(request.file_content)
            user_vector_stores[uid] = create_vector_store(doc_chunks)

        # Run the query; console I/O callbacks are stubbed out because this
        # is an HTTP endpoint, not an interactive session.
        answer, links = orchestrate(
            request.query,
            vector_store=user_vector_stores[uid],
            comm_outp=lambda x: None,  # Disable direct printing
            comm_inp=lambda x: None,  # Disable direct input
        )

        # Execute any fenced code blocks embedded in the answer.
        found_blocks = re.findall(
            Config.code_block_regex(), answer, re.DOTALL)
        execution_results = (
            execute_code_blocks(
                uid, found_blocks, request.query, answer, links, doc_chunks)
            if found_blocks
            else []
        )

        return {
            "response": answer,
            "links": links,
            "execution_results": execution_results
        }
    except Exception as e:
        # Boundary handler: log and surface the failure as a 500.
        debugMod.log(f"API error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def execute_code_blocks(user_id, code_blocks, query, response, links, chunks):
    """Run each extracted code block for the given user.

    Looks up the user's sandbox environment and vector store, then invokes
    orchestrate_code once per block, returning the results in block order.
    """
    env = user_environments[user_id]
    store = user_vector_stores[user_id]

    # One orchestrate_code invocation per block; each block is passed as a
    # single-element list, matching the original call shape.
    return [
        orchestrate_code(
            orchestrate,
            store,
            chunks,
            env,
            [block],
            query,
            response,
            links,
        )
        for block in code_blocks
    ]
|
||||
|
||||
|
||||
def load_and_chunk_from_content(content: str):
    """Split raw uploaded text into overlapping chunks.

    Chunk size and overlap come from Config; returns the list of text
    chunks produced by the recursive character splitter.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=Config.CHUNK_SIZE,
        chunk_overlap=Config.CHUNK_OVERLAP)
    return text_splitter.split_text(content)
|
||||
|
||||
|
||||
# Allow cross-origin requests from any frontend origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec (browsers reject `Access-Control-Allow-Origin: *`
# on credentialed requests) — confirm whether credentials are actually needed,
# or pin the allowed origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
Reference in New Issue
Block a user