// rag server: in-memory vector store with ollama embeddings/chat, served over http by bun
import { serve } from "bun";
import fs from "node:fs";
import path from "node:path";

// types
/** one embedded piece of text stored inside a collection. */
interface Chunk {
  /** unique id of the chunk */
  id: string;
  /** the chunk's text */
  text: string;
  /** optional caller-supplied metadata carried along with the text */
  metadata?: Record<string, unknown>;
  /** embedding vector for `text` */
  vector: number[];
}
/** a named group of chunks; the in-memory db maps collection name -> Collection. */
interface Collection {
  /** collection name (also used as its key in the db) */
  name: string;
  /** every chunk indexed into this collection */
  chunks: Chunk[];
}
/** a single message in an ollama chat conversation. */
interface OllamaChatMessage {
  /** who authored the message */
  role: "system" | "user" | "assistant";
  /** message text */
  content: string;
}
/** arguments accepted by `ollamaChat`. */
interface OllamaChatRequest {
  /** model name; `ollamaChat` falls back to OLLAMA_CHAT_MODEL when omitted */
  model?: string;
  /** conversation to send, in order */
  messages: OllamaChatMessage[];
  /** streaming flag of the ollama chat api */
  stream?: boolean;
}
/** the part of ollama's chat response this server reads; other fields are passed through untyped. */
interface OllamaChatResponse {
  /** the assistant's reply, when present */
  message?: OllamaChatMessage;
  /** any additional fields returned by ollama */
  [k: string]: unknown;
}
/** one document submitted to /upsert; its text is split into chunks before embedding. */
interface UpsertInputItem {
  /** raw document text */
  text: string;
  /** optional metadata copied onto every chunk produced from this item */
  metadata?: Record<string, unknown>;
}
/** minimal shape of the openapi document served at /openapi.json. */
interface OpenAPIObject {
  /** openapi specification version string */
  openapi: string;
  /** api title and version */
  info: { title: string; version: string };
  /** path -> path item; left loosely typed */
  paths: Record<string, unknown>;
}

// env: runtime configuration, each value overridable through an environment variable
// tcp port to listen on; falls back to 8788 when PORT is unset or is not an
// integer in 0..65535 (previously a bad value silently became NaN)
const PORT = ((): number => {
  const raw = process.env.PORT;
  if (!raw) return 8788;
  const parsed = Number(raw);
  if (Number.isInteger(parsed) && parsed >= 0 && parsed <= 65535) return parsed;
  console.warn(`invalid PORT "${raw}", falling back to 8788`);
  return 8788;
})();
// interface to bind
const HOST = process.env.HOST || "0.0.0.0";
// base url of the ollama http api
const OLLAMA_BASE = process.env.OLLAMA_BASE || "http://localhost:11434";
// default models for chat and for embeddings
const OLLAMA_CHAT_MODEL = process.env.OLLAMA_CHAT_MODEL || "llama3.1";
const OLLAMA_EMBED_MODEL = process.env.OLLAMA_EMBED_MODEL || "nomic-embed-text";
// directory holding persisted state, and the snapshot file inside it
const DATA_DIR = process.env.DATA_DIR || path.resolve("./data");
const SNAPSHOT = path.join(DATA_DIR, "rag.json");

// in-memory db
// collection name -> collection; starts empty
const db = new Map<string, Collection>();

// util: small json persistence
/** create DATA_DIR (and any missing parent directories) so the snapshot can be stored there. */
function ensureDirs(): void {
  // a recursive mkdir does nothing when the directory already exists, so no
  // existsSync pre-check is needed (and there is no check-then-create race)
  fs.mkdirSync(DATA_DIR, { recursive: true });
}

// load collections from the on-disk snapshot into the in-memory db
/**
 * read the SNAPSHOT file (if it exists) and copy its collections into `db`.
 * best-effort: any failure is logged with console.warn and never thrown.
 */
function loadSnapshot(): void {
  try {
    ensureDirs();
    if (!fs.existsSync(SNAPSHOT)) return;

    const raw = fs.readFileSync(SNAPSHOT, "utf8");
    // an empty file is treated as an empty db
    const parsed: unknown = JSON.parse(raw || "{}");
    if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) {
      console.warn("failed to load snapshot:", new Error("snapshot root is not an object"));
      return;
    }

    for (const [name, value] of Object.entries(parsed)) {
      // the file is external input: only accept entries that carry a chunks
      // array, otherwise later reads of col.chunks would crash
      const chunks = (value as { chunks?: unknown } | null)?.chunks;
      if (!Array.isArray(chunks)) {
        console.warn(`skipping malformed collection "${name}" in snapshot`);
        continue;
      }
      db.set(name, value as Collection);
    }
  } catch (e) {
    console.warn("failed to load snapshot:", e);
  }
}

// write the in-memory db to the on-disk snapshot
/**
 * serialize `db` as pretty-printed json into the SNAPSHOT file.
 * best-effort: any failure is logged with console.warn and never thrown.
 */
function saveSnapshot(): void {
  try {
    ensureDirs();
    const obj = Object.fromEntries(db.entries());
    // write to a sibling temp file and rename it over the snapshot, so an
    // interrupted write cannot leave a truncated rag.json behind
    const tmp = `${SNAPSHOT}.tmp`;
    fs.writeFileSync(tmp, JSON.stringify(obj, null, 2));
    fs.renameSync(tmp, SNAPSHOT);
  } catch (e) {
    console.warn("failed to save snapshot:", e);
  }
}
// populate db from the snapshot file once, when the module is loaded
loadSnapshot();

// basic text splitter (by paragraph, then by sentence punctuation, then by length)
/**
 * split text into chunks: first on blank lines, then after sentence-ending
 * punctuation (. ! ?), and finally any sentence longer than `maxLen` is cut
 * into pieces of at most `maxLen` characters. chunks are trimmed and empty
 * ones are dropped.
 *
 * @param text   text to split
 * @param maxLen maximum chunk length; fractional values are floored
 * @returns the chunks, in document order
 * @throws RangeError when `maxLen` is NaN or floors to less than 1
 */
function chunkText(text: string, maxLen = 800): string[] {
  const limit = Math.floor(maxLen);
  // a limit of 0 used to produce an invalid regex (SyntaxError) and a
  // fractional limit one that matched nothing, silently dropping the text
  if (Number.isNaN(limit) || limit < 1) {
    throw new RangeError(`maxLen must be a positive number, got ${maxLen}`);
  }

  // built at most once per call, and only if some sentence is too long
  let hardSplit: RegExp | undefined;
  const chunks: string[] = [];

  for (const paragraph of text.split(/\n{2,}/)) {
    for (const sentence of paragraph.split(/(?<=[.!?])\s+/)) {
      let pieces: string[] = [sentence];
      if (sentence.length > limit) {
        hardSplit ??= new RegExp(`.{1,${limit}}`, "g");
        // `.` does not match line breaks, so pieces never span a newline
        pieces = sentence.match(hardSplit) ?? [];
      }
      for (const piece of pieces) {
        const trimmed = piece.trim();
        if (trimmed) chunks.push(trimmed);
      }
    }
  }
  return chunks;
}

// cosine similarity
/** dot product over the indices of `a`; missing or NaN entries on either side count as 0. */
function dot(a: number[], b: number[]): number {
  let total = 0;
  let idx = 0;
  while (idx < a.length) {
    const left = a[idx] || 0;
    const right = b[idx] || 0;
    total += left * right;
    idx += 1;
  }
  return total;
}
/** euclidean length of `a`: the square root of its dot product with itself. */
function norm(a: number[]): number {
  const squaredLength = dot(a, a);
  return Math.sqrt(squaredLength);
}
/**
 * cosine similarity of `a` and `b`. when the product of the norms is 0 (or
 * NaN) the divisor becomes 1, so a zero vector scores 0 instead of NaN.
 */
function cosineSim(a: number[], b: number[]): number {
  const numerator = dot(a, b);
  const magnitude = norm(a) * norm(b);
  const denominator = magnitude || 1;
  return numerator / denominator;
}

// call ollama embeddings
/**
 * embed every text with OLLAMA_EMBED_MODEL.
 *
 * tries the batch endpoint `/api/embed` first; if that answers with a non-ok
 * status it falls back to one `/api/embeddings` request per text, in order.
 *
 * @param texts texts to embed
 * @returns one vector per input text, in the same order
 * @throws Error when the fallback request fails or a response does not
 *         contain the expected embedding(s)
 */
async function embedAll(texts: string[]): Promise<number[][]> {
  // nothing to embed: skip the network entirely
  if (texts.length === 0) return [];

  const headers = { "content-type": "application/json" };

  const primary = await fetch(`${OLLAMA_BASE}/api/embed`, {
    method: "POST",
    headers,
    body: JSON.stringify({ model: OLLAMA_EMBED_MODEL, input: texts })
  });

  if (primary.ok) {
    const j = (await primary.json()) as { embeddings?: unknown };
    // callers pair vectors with texts by index, so the counts must agree
    if (!Array.isArray(j.embeddings) || j.embeddings.length !== texts.length) {
      const got = Array.isArray(j.embeddings) ? j.embeddings.length : 0;
      throw new Error(`embed failed: expected ${texts.length} embeddings, got ${got}`);
    }
    return j.embeddings as number[][];
  }

  // fallback: single-prompt endpoint, called sequentially
  const results: number[][] = [];
  for (const t of texts) {
    const r = await fetch(`${OLLAMA_BASE}/api/embeddings`, {
      method: "POST",
      headers,
      body: JSON.stringify({ model: OLLAMA_EMBED_MODEL, prompt: t })
    });

    if (!r.ok) throw new Error(`embed failed: ${r.status}`);

    const j = (await r.json()) as { embedding?: unknown };
    if (!Array.isArray(j.embedding)) throw new Error("embed failed: response has no embedding");
    results.push(j.embedding as number[]);
  }
  return results;
}

// call ollama chat with the prepared messages
/**
 * send a chat request to ollama's `/api/chat` and return the parsed reply.
 *
 * @param req messages plus an optional model (defaults to OLLAMA_CHAT_MODEL)
 * @returns the json body of the response
 * @throws Error when ollama answers with a non-ok status
 */
async function ollamaChat(req: OllamaChatRequest): Promise<OllamaChatResponse> {
  const res = await fetch(`${OLLAMA_BASE}/api/chat`, {
    method: "POST",
    headers: { "content-type": "application/json" },
    body: JSON.stringify({
      model: req.model || OLLAMA_CHAT_MODEL,
      messages: req.messages,
      // the reply is parsed below as one json document, so streaming is always
      // turned off. forwarding req.stream (or omitting it when undefined, which
      // lets ollama fall back to its streaming default) would yield
      // newline-delimited json that res.json() cannot parse
      stream: false
    })
  });

  if (!res.ok) throw new Error(`ollama chat failed: ${res.status}`);

  return (await res.json()) as OllamaChatResponse;
}

// openapi for open webui tool integration
// openapi document served at /openapi.json. every operation carries a summary
// and, for POST, a json request-body schema mirroring the fields the matching
// handler reads, so a tool client can tell which arguments to send.
const OPENAPI: OpenAPIObject = {
  openapi: "3.1.0",
  info: { title: "RAG Server (Ollama)", version: "1.0.0" },
  paths: {
    "/collections": {
      get: { operationId: "listCollections", summary: "list collection names" },
      post: {
        operationId: "createCollection",
        summary: "create a collection if it does not exist yet",
        requestBody: {
          required: true,
          content: {
            "application/json": {
              schema: {
                type: "object",
                required: ["name"],
                properties: { name: { type: "string" } }
              }
            }
          }
        }
      }
    },
    "/upsert": {
      post: {
        operationId: "upsert",
        summary: "chunk, embed and store documents in a collection",
        requestBody: {
          required: true,
          content: {
            "application/json": {
              schema: {
                type: "object",
                required: ["collection", "items"],
                properties: {
                  collection: { type: "string" },
                  items: {
                    type: "array",
                    items: {
                      type: "object",
                      required: ["text"],
                      properties: { text: { type: "string" }, metadata: { type: "object" } }
                    }
                  }
                }
              }
            }
          }
        }
      }
    },
    "/query": {
      post: {
        operationId: "query",
        summary: "return the chunks most similar to a query",
        requestBody: {
          required: true,
          content: {
            "application/json": {
              schema: {
                type: "object",
                required: ["collection", "query"],
                properties: {
                  collection: { type: "string" },
                  query: { type: "string" },
                  topK: { type: "integer", minimum: 1, default: 5 }
                }
              }
            }
          }
        }
      }
    },
    "/chat": {
      post: {
        operationId: "chat",
        summary: "answer a question using retrieved chunks as context",
        requestBody: {
          required: true,
          content: {
            "application/json": {
              schema: {
                type: "object",
                required: ["collection", "query"],
                properties: {
                  collection: { type: "string" },
                  query: { type: "string" },
                  topK: { type: "integer", minimum: 1, default: 5 },
                  model: { type: "string" }
                }
              }
            }
          }
        }
      }
    }
  }
};

// tiny router
/** parse the request body as json; resolves to an empty object when the body cannot be parsed. */
async function json<T = any>(req: Request): Promise<T> {
  let parsed: T;
  try {
    parsed = (await req.json()) as T;
  } catch {
    parsed = {} as T;
  }
  return parsed;
}
/**
 * build a json Response.
 *
 * @param _res   unused; kept so existing call sites (which pass null) stay valid
 * @param status http status code
 * @param obj    value serialized with JSON.stringify as the body
 */
function sendJson(_res: unknown, status: number, obj: unknown): Response {
  return new Response(JSON.stringify(obj), {
    status,
    headers: { "content-type": "application/json; charset=utf-8" }
  });
}
/**
 * /collections endpoint.
 * - GET  -> 200 `{ collections: string[] }` with every collection name
 * - POST -> creates the collection named in the body (`{ name }`) if it is
 *           missing; 400 when the name is absent or not a usable value
 * - any other method -> 404
 */
async function handleCollections(req: Request): Promise<Response> {
  if (req.method === "GET") {
    return sendJson(null, 200, { collections: Array.from(db.keys()) });
  }

  if (req.method === "POST") {
    const body = await json<{ name?: unknown }>(req);
    const rawName = body?.name;
    // only strings and numbers make sensible names; objects/booleans used to be
    // coerced into names such as "[object Object]" or "true"
    const name = typeof rawName === "string" || typeof rawName === "number" ? String(rawName || "").trim() : "";

    if (!name) return sendJson(null, 400, { error: "name required" });

    // persist only when something actually changed
    if (!db.has(name)) {
      db.set(name, { name, chunks: [] });
      saveSnapshot();
    }
    return sendJson(null, 200, { ok: true });
  }

  return new Response("not found", { status: 404 });
}
/**
 * /upsert endpoint: body `{ collection, items: [{ text, metadata? }] }`.
 * splits every item's text into chunks, embeds them, appends them to the
 * collection (creating it when missing) and saves the snapshot.
 *
 * responses: 200 `{ ok, indexed }`, 400 when `collection` is missing,
 * 502 when embedding fails or returns the wrong number of vectors.
 */
async function handleUpsert(req: Request): Promise<Response> {
  const body = await json<{ collection?: string; items?: unknown }>(req);
  const collection = String(body?.collection || "").trim();
  const items: unknown[] = Array.isArray(body?.items) ? body.items : [];

  if (!collection) return sendJson(null, 400, { error: "collection required" });

  const pending: { id: string; text: string; metadata: Record<string, unknown> }[] = [];
  for (const it of items) {
    // ignore entries that are not objects (a null item used to crash the handler)
    if (typeof it !== "object" || it === null) continue;
    const { text, metadata } = it as Partial<UpsertInputItem>;
    for (const part of chunkText(String(text || ""))) {
      pending.push({ id: crypto.randomUUID(), text: part, metadata: metadata || {} });
    }
  }

  // embed before touching db, so a failed embedding leaves no half-created
  // collection and no chunks without vectors
  let vecs: number[][] = [];
  if (pending.length > 0) {
    try {
      vecs = await embedAll(pending.map((p) => p.text));
    } catch (e) {
      return sendJson(null, 502, { error: `embedding failed: ${e instanceof Error ? e.message : String(e)}` });
    }
    if (vecs.length !== pending.length) {
      return sendJson(null, 502, { error: "embedding failed: vector count mismatch" });
    }
  }

  let col = db.get(collection);
  if (!col) {
    col = { name: collection, chunks: [] };
    db.set(collection, col);
  }

  for (let i = 0; i < pending.length; i++) {
    const item = pending[i];
    const doc: Chunk = { id: item.id, text: item.text, metadata: item.metadata, vector: vecs[i] };
    col.chunks.push(doc);
  }

  saveSnapshot();
  return sendJson(null, 200, { ok: true, indexed: pending.length });
}
/**
 * /query endpoint: body `{ collection, query, topK? }`.
 * embeds the query and returns the `topK` chunks of the collection with the
 * highest cosine similarity, best first.
 *
 * responses: 200 `{ matches: [{ id, text, metadata, score }] }`, 400 when
 * `collection` or `query` is missing, 404 for an unknown collection, 502 when
 * the query cannot be embedded.
 */
async function handleQuery(req: Request): Promise<Response> {
  const body = await json<{ collection?: string; query?: string; topK?: number }>(req);
  const collection = String(body?.collection || "").trim();
  const query = String(body?.query || "").trim();
  // topK must be a positive integer; anything else falls back to 5. a negative
  // value used to reach slice(0, -n) and drop the last n matches instead
  const requestedK = Math.floor(Number(body?.topK));
  const topK = Number.isFinite(requestedK) && requestedK > 0 ? requestedK : 5;

  if (!collection || !query) return sendJson(null, 400, { error: "collection and query required" });

  const col = db.get(collection);
  if (!col) return sendJson(null, 404, { error: "collection not found" });

  // nothing to rank: skip the embedding round-trip
  if (col.chunks.length === 0) return sendJson(null, 200, { matches: [] });

  let embedded: number[][];
  try {
    embedded = await embedAll([query]);
  } catch (e) {
    return sendJson(null, 502, { error: `embedding failed: ${e instanceof Error ? e.message : String(e)}` });
  }
  const qvec = embedded[0];
  if (!qvec) return sendJson(null, 502, { error: "embedding failed: no vector returned" });

  const scored = col.chunks
    .map((c) => ({ c, score: cosineSim(qvec, c.vector) }))
    .sort((a, b) => b.score - a.score)
    .slice(0, topK)
    .map((x) => ({ id: x.c.id, text: x.c.text, metadata: x.c.metadata, score: x.score }));

  return sendJson(null, 200, { matches: scored });
}
/**
 * /chat endpoint: body `{ collection, query, topK?, model? }`.
 * retrieves the `topK` most similar chunks, puts them into the prompt as
 * numbered context and asks the chat model to answer from that context only.
 *
 * responses: 200 `{ answer, citations: [{ id, score, text }] }`, 400 when
 * `collection` or `query` is missing, 404 for an unknown collection, 502 when
 * the embedding or chat call fails.
 */
async function handleChat(req: Request): Promise<Response> {
  const body = await json<{ collection?: string; query?: string; topK?: number; model?: string }>(req);
  const collection = String(body?.collection || "").trim();
  const query = String(body?.query || "").trim();
  // topK must be a positive integer; anything else falls back to 5. a negative
  // value used to reach slice(0, -n) and drop the last n matches instead
  const requestedK = Math.floor(Number(body?.topK));
  const topK = Number.isFinite(requestedK) && requestedK > 0 ? requestedK : 5;
  const model = body?.model || OLLAMA_CHAT_MODEL;

  if (!collection || !query) return sendJson(null, 400, { error: "collection and query required" });

  const col = db.get(collection);
  if (!col) return sendJson(null, 404, { error: "collection not found" });

  const upstreamError = (what: string, e: unknown): Response =>
    sendJson(null, 502, { error: `${what} failed: ${e instanceof Error ? e.message : String(e)}` });

  // rank the stored chunks against the query (no embedding call when there are none)
  let matches: { c: Chunk; score: number }[] = [];
  if (col.chunks.length > 0) {
    let embedded: number[][];
    try {
      embedded = await embedAll([query]);
    } catch (e) {
      return upstreamError("embedding", e);
    }
    const qvec = embedded[0];
    if (!qvec) return upstreamError("embedding", "no vector returned");
    matches = col.chunks
      .map((c) => ({ c, score: cosineSim(qvec, c.vector) }))
      .sort((a, b) => b.score - a.score)
      .slice(0, topK);
  }

  // context docs are numbered from 1 so the model can cite them as [doc N]
  const context = matches.map((m, i) => `[[doc ${i + 1} score=${m.score.toFixed(3)}]]\n${m.c.text}`).join("\n\n");
  const system = `you are a helpful assistant. use ONLY the provided context to answer. if the answer isn't in the context, say you don't know. cite as [doc N].`;
  const user = `question: ${query}\n\ncontext:\n${context}`;

  let out: OllamaChatResponse;
  try {
    out = await ollamaChat({
      model,
      messages: [{ role: "system", content: system }, { role: "user", content: user }],
      stream: false
    });
  } catch (e) {
    return upstreamError("chat", e);
  }

  return sendJson(null, 200, {
    answer: out?.message?.content || "",
    citations: matches.map((m) => ({ id: m.c.id, score: m.score, text: m.c.text }))
  });
}
/**
 * map a request pathname to its handler.
 *
 * @param pathname url path, matched exactly
 * @returns the handler for that path, or undefined when no route matches
 */
const pickFunc = (pathname: string): ((req: Request) => Promise<Response>) | undefined => {
  switch (pathname) {
    case "/collections":
      return handleCollections;
    case "/upsert":
      return handleUpsert;
    case "/query":
      return handleQuery;
    case "/chat":
      return handleChat;
    default:
      return undefined;
  }
};
// http entry point. routes:
//   GET /              -> plain "ok"
//   GET /openapi.json  -> the OPENAPI document
//   anything else      -> the handler chosen by pickFunc, or 404
const server = serve({
  port: PORT,
  hostname: HOST,
  fetch: async (req: Request): Promise<Response> => {
    const u = new URL(req.url);
    if (req.method === "GET" && u.pathname === "/") return new Response("ok");
    if (req.method === "GET" && u.pathname === "/openapi.json") return sendJson(null, 200, OPENAPI);

    const handler = pickFunc(u.pathname);
    if (!handler) return new Response("not found", { status: 404 });

    try {
      // pass the request as the argument. the previous `?.call(req)` bound it
      // as `this` instead, so every handler received `undefined`
      return await handler(req);
    } catch (e) {
      console.error("[rag] request failed:", e);
      return sendJson(null, 500, { error: e instanceof Error ? e.message : "internal error" });
    }
  }
});
// announce the bound address once the server has been created
console.log(`[rag] listening on http://${HOST}:${PORT}`);