mirror of
https://github.com/ION606/ML-pipeline.git
synced 2026-05-14 21:06:54 +00:00
129 lines
3.4 KiB
Python
129 lines
3.4 KiB
Python
|
|
import re
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# NOTE(review): conversation_store and codeExecution.UserEnvironment are imported
# twice below; the duplicates are harmless and can be removed.
from main import orchestrate, create_vector_store, load_and_chunk_file
from codeExecution import UserEnvironment
import conversation_store
from config import Config
from codeExecution import UserEnvironment, orchestrate_code
from queries import (
    perform_web_search,
    rag_query,
    classify_task,
    MODEL_NAMES,
)
import debug as debugMod
import conversation_store
|
||
|
|
# Module-level setup: runs once, at import time, before any route is registered.
# NOTE(review): initialize_db() presumably creates/prepares the conversation
# database — confirm in conversation_store.
conversation_store.initialize_db()

# The ASGI application; the route and CORS middleware below are registered on it.
app = FastAPI()
|
||
|
|
|
||
|
|
|
||
|
|
class UserRequest(BaseModel):
    """Request body accepted by ``POST /process-query``."""

    # The user's question / instruction, passed to orchestrate().
    query: str
    # Identifies the caller; used as the key for per-user environment and vector store.
    user_id: str
    # Optional raw text to chunk and index for this user. Declared Optional[str]
    # (not plain ``str``) so an explicit JSON null is accepted like an omitted
    # field — pydantic v2 rejects null for a bare ``str`` even with a None default.
    file_content: Optional[str] = None
|
||
|
|
|
||
|
|
|
||
|
|
class CodeExecutionResponse(BaseModel):
    """Result record with output text, an optional error, and related links.

    NOTE(review): not referenced by the endpoint in this file (process_query
    returns a plain dict) — confirm whether it is used elsewhere.
    """

    # Output text.
    result: str
    # Error message, if any. Optional[str] so None is a valid value, not only a default.
    error: Optional[str] = None
    # Related links. pydantic copies non-hashable defaults per instance,
    # so this [] is not shared between model instances.
    links: list[str] = []
|
||
|
|
|
||
|
|
|
||
|
|
# Maintain user environments and vector stores.
# Both dicts are keyed by UserRequest.user_id and populated lazily by
# process_query; nothing in this module ever removes entries.
# NOTE(review): plain in-process dicts — unbounded, and each worker process
# (if the server runs several) holds its own independent copy.
user_environments: dict[str, UserEnvironment] = {}
# Vector store built from the user's most recent file_content upload, or None
# if that user has not uploaded anything yet.
user_vector_stores: dict[str, object] = {}
|
||
|
|
|
||
|
|
|
||
|
|
@app.post("/process-query")
async def process_query(request: UserRequest):
    """Answer a user query, optionally indexing uploaded text and running code.

    Per-user state (a ``UserEnvironment`` and a vector store slot) is created
    on the first request for ``request.user_id`` and reused afterwards. When
    ``request.file_content`` is non-empty it is chunked and indexed into a new
    vector store that replaces the user's previous one.

    Args:
        request: The query, the caller's user id, and optional raw file text.

    Returns:
        A dict with ``response`` and ``links`` (as returned by ``orchestrate``)
        and ``execution_results`` (one entry per code block matched in the
        response by ``Config.code_block_regex()``; empty when none match).

    Raises:
        HTTPException: status 500 carrying the error message when processing
            fails. An HTTPException raised by downstream code is propagated
            unchanged instead of being re-wrapped as a 500.
    """
    # NOTE(review): orchestrate()/execute_code_blocks() are called synchronously
    # inside an async endpoint; if they block, they stall the event loop for all
    # requests — consider fastapi.concurrency.run_in_threadpool.
    try:
        # Create per-user state on first contact.
        if request.user_id not in user_environments:
            user_environments[request.user_id] = UserEnvironment(request.user_id)
            user_vector_stores[request.user_id] = None

        # Index uploaded content; this replaces any earlier store for the user.
        # NOTE(review): on later requests without file_content, chunks stays []
        # while the previously stored vector store is reused — confirm that
        # orchestrate_code tolerates that combination.
        chunks = []
        if request.file_content:
            chunks = load_and_chunk_from_content(request.file_content)
            user_vector_stores[request.user_id] = create_vector_store(chunks)

        # Process query
        vector_store = user_vector_stores[request.user_id]
        response, links = orchestrate(
            request.query,
            vector_store=vector_store,
            comm_outp=lambda x: None,  # Disable direct printing
            comm_inp=lambda x: None,  # Disable direct input
        )

        # Hand any code blocks found in the response to execute_code_blocks.
        code_blocks = re.findall(Config.code_block_regex(), response, re.DOTALL)
        execution_results = []
        if code_blocks:
            execution_results = execute_code_blocks(
                request.user_id,
                code_blocks,
                request.query,
                response,
                links,
                chunks,
            )

        return {
            "response": response,
            "links": links,
            "execution_results": execution_results,
        }
    except HTTPException:
        # Already carries a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        debugMod.log(f"API error: {str(e)}")
        # NOTE(review): str(e) is sent to the client verbatim and may expose internals.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
|
|
||
|
|
|
||
|
|
def execute_code_blocks(user_id, code_blocks, query, response, links, chunks):
    """Run every extracted code block through ``orchestrate_code`` for one user.

    Args:
        user_id: Key into ``user_environments`` / ``user_vector_stores``; the
            entries must already exist.
        code_blocks: Code blocks extracted from the model response, handled in order.
        query: The original user query, forwarded to ``orchestrate_code``.
        response: The model response the blocks were extracted from.
        links: The links returned alongside ``response``.
        chunks: Text chunks from the current request's upload ([] if none).

    Returns:
        A list holding one ``orchestrate_code`` result per block, in input order.

    Raises:
        KeyError: if ``user_id`` has no registered environment or vector store.
    """
    # Resolve per-user state up front (raises KeyError even for an empty block list).
    env = user_environments[user_id]
    store = user_vector_stores[user_id]

    # One call per block, each given a single-element list, so results line up
    # 1:1 with code_blocks.
    return [
        orchestrate_code(
            orchestrate, store, chunks, env, [block], query, response, links
        )
        for block in code_blocks
    ]
|
||
|
|
|
||
|
|
|
||
|
|
def load_and_chunk_from_content(content: str):
    """Split raw text into chunks using the sizes configured in ``Config``.

    Args:
        content: The text to split.

    Returns:
        The chunks produced by ``RecursiveCharacterTextSplitter.split_text``,
        configured with ``Config.CHUNK_SIZE`` and ``Config.CHUNK_OVERLAP``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=Config.CHUNK_SIZE,
        chunk_overlap=Config.CHUNK_OVERLAP,
    )
    pieces = text_splitter.split_text(content)
    return pieces
|
||
|
|
|
||
|
|
|
||
|
|
# Enable CORS for browser clients: any origin, method, and header, with credentials.
# NOTE(review): SECURITY — allow_origins=["*"] together with allow_credentials=True
# is a risky combination. Starlette's docs say wildcards must not be used when
# credentials are allowed, and its CORSMiddleware reflects the caller's Origin in
# that case, which would let any website make credentialed requests. If no
# cookie/auth-based browser clients exist, set allow_credentials=False; otherwise
# list the allowed front-end origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|