Files
ML-pipeline/api/api.py
T

129 lines
3.4 KiB
Python
Raw Normal View History

2025-04-01 22:29:59 -04:00
import re
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from pydantic import BaseModel

from main import orchestrate, create_vector_store, load_and_chunk_file
from codeExecution import UserEnvironment
import conversation_store
from config import Config
from codeExecution import UserEnvironment, orchestrate_code
from queries import (
    perform_web_search,
    rag_query,
    classify_task,
    MODEL_NAMES
)
import debug as debugMod
import conversation_store
# Import-time setup. conversation_store.initialize_db() runs first (presumably
# creating/opening the conversation database - confirm in conversation_store),
# then the FastAPI application object that the routes below attach to is built.
conversation_store.initialize_db()
app = FastAPI()
class UserRequest(BaseModel):
    """Request body for ``POST /process-query``."""

    # The user's query text; forwarded to ``orchestrate``.
    query: str
    # Key into the per-user environment / vector-store dictionaries.
    user_id: str
    # Optional raw text to chunk and index for this user. Annotated as
    # ``Optional`` so an explicit JSON ``null`` is accepted as well as omitting
    # the field (with a bare ``str`` annotation pydantic v2 rejects ``null``).
    file_content: Optional[str] = None
class CodeExecutionResponse(BaseModel):
    """Result model for a code-execution step.

    Not used as a response model by the endpoint in this file.
    """

    # Output produced by the execution.
    result: str
    # Error text, or None when there was no error. ``Optional`` so that an
    # explicit ``error=None`` validates under pydantic v2.
    error: Optional[str] = None
    # Links associated with the result; empty by default.
    links: list[str] = []
# Per-process, in-memory state keyed by ``request.user_id``:
#   user_environments  -> the UserEnvironment created on that user's first request
#   user_vector_stores -> vector store built from the user's last upload, or None
# NOTE(review): nothing in this file ever removes entries, so both dicts grow
# with every distinct user_id; the state is also lost when the process exits.
user_environments, user_vector_stores = {}, {}
@app.post("/process-query")
async def process_query(request: UserRequest):
    """Answer a user's query, optionally indexing uploaded text, and run any code in the answer.

    On the first request for a ``user_id`` a ``UserEnvironment`` is created and
    the user's vector-store slot is set to None. When ``file_content`` is
    non-empty it is split into chunks and a new vector store built from them
    replaces the user's previous one. The query is then passed to
    ``orchestrate`` with that store; code blocks matched in the response by
    ``Config.code_block_regex()`` are handed to ``execute_code_blocks``.

    Args:
        request: the parsed request body.

    Returns:
        dict with keys ``"response"``, ``"links"`` and ``"execution_results"``
        (an empty list when the response contains no code blocks).

    Raises:
        HTTPException: an ``HTTPException`` raised during handling is re-raised
            unchanged; any other exception is logged via ``debugMod.log`` and
            converted to a 500 whose detail is the exception text.
    """
    # NOTE(review): none of the calls below are awaited, so while they run this
    # ``async def`` handler holds the event loop and no other request can make
    # progress. Moving the work off the loop would also require guarding the
    # shared per-user dicts against concurrent access.
    try:
        user_id = request.user_id
        if user_id not in user_environments:
            user_environments[user_id] = UserEnvironment(user_id)
            user_vector_stores[user_id] = None

        # NOTE(review): chunks exist only for the request that uploads content;
        # later requests reuse the stored vector store but pass an empty chunk
        # list on to code execution. Confirm orchestrate_code tolerates that.
        chunks = []
        if request.file_content:
            chunks = load_and_chunk_from_content(request.file_content)
            user_vector_stores[user_id] = create_vector_store(chunks)

        response, links = orchestrate(
            request.query,
            vector_store=user_vector_stores[user_id],
            comm_outp=lambda x: None,  # Disable direct printing
            comm_inp=lambda x: None,  # Disable direct input
        )

        # NOTE(review): the code blocks come from model output driven by the
        # client's query and are presumably executed server-side - make sure
        # UserEnvironment sandboxes them.
        code_blocks = re.findall(Config.code_block_regex(), response, re.DOTALL)
        execution_results = []
        if code_blocks:
            execution_results = execute_code_blocks(
                user_id,
                code_blocks,
                request.query,
                response,
                links,
                chunks,
            )

        return {
            "response": response,
            "links": links,
            "execution_results": execution_results,
        }
    except HTTPException:
        # Already carries its own status code; don't mask it as a 500.
        raise
    except Exception as e:
        debugMod.log(f"API error: {str(e)}")
        # Chain the cause so the original traceback is preserved server-side.
        raise HTTPException(status_code=500, detail=str(e)) from e
def execute_code_blocks(user_id, code_blocks, query, response, links, chunks):
    """Submit each extracted code block to ``orchestrate_code`` for one user.

    Blocks are submitted one at a time, each wrapped in a single-element list,
    together with the user's environment and vector store and the originating
    query, response, links and chunks. ``orchestrate`` itself is passed along
    as a callable.

    Returns:
        list holding one ``orchestrate_code`` result per block, in input order.

    Raises:
        KeyError: if ``user_id`` has no entry in ``user_environments`` or
            ``user_vector_stores``.
    """
    env = user_environments[user_id]
    store = user_vector_stores[user_id]
    return [
        orchestrate_code(
            orchestrate,
            store,
            chunks,
            env,
            [block],
            query,
            response,
            links,
        )
        for block in code_blocks
    ]
def load_and_chunk_from_content(content: str):
    """Split raw text into chunks using the configured chunk size and overlap.

    A fresh ``RecursiveCharacterTextSplitter`` is built on every call from
    ``Config.CHUNK_SIZE`` and ``Config.CHUNK_OVERLAP``.

    Args:
        content: the text to split.

    Returns:
        the result of the splitter's ``split_text`` for ``content``.
    """
    splitter_settings = {
        "chunk_size": Config.CHUNK_SIZE,
        "chunk_overlap": Config.CHUNK_OVERLAP,
    }
    return RecursiveCharacterTextSplitter(**splitter_settings).split_text(content)
# Permit cross-origin requests from any origin, with any method and header,
# and with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# website call this API from a visitor's browser (Starlette echoes the request
# Origin back for credentialed requests - verify for the installed version).
# Narrow allow_origins before exposing /process-query, which presumably runs
# code, to untrusted networks.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)