// embedwith/nodeapp/run.js
// (file-listing residue retained as a comment: 371 lines, 9.4 KiB, JavaScript)

import http from 'http'
import url from 'url'
import crypto from 'crypto'
import child_process from 'node:child_process'
import util from 'node:util'
const execFile = util.promisify(child_process.execFile)
import { DataType } from '@zilliz/milvus2-sdk-node'
import fs from 'fs'
import path from 'path'
// Load sidecar config from the same directory as this module.
// MODELKEY_SERVER_MAP assumed shape: { [model_key]: { endpoint, prefix?, dim } } — verify against config.json
const CONFIG = JSON.parse(fs.readFileSync(path.resolve(import.meta.dirname, './config.json'), 'utf8'))
const { MODELKEY_SERVER_MAP, MILVUS_HOST } = CONFIG
// init a single Milvus client for this process lifecycle
import { async_get_client } from './milvus_utils.js'
const client = await async_get_client(MILVUS_HOST)
// Best-effort close of the shared Milvus connection when the process exits.
process.on('exit', (code) => {
client.closeConnection()
return
})
// Listen address for the HTTP server started at the bottom of this file.
const PORT = 8801
const HOST = '0.0.0.0'
// Render the minimal HTML shell; it only bootstraps the client bundle at /index.js.
const get_siteroot_html = () => {
  const page_title = 'embedwith'
  const str_page = `
<!DOCTYPE html>
<html lang=en>
<head>
<meta charset=utf-8>
<link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg'/>">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>${page_title}</title>
</head>
<body>
<script type="text/javascript" src="/index.js">
</script>
</body>
</html>
`
  return str_page.trim()
}
// Bundle the browser app with bun and return the emitted JS as a string.
// NOTE(review): rebuilds on every call, and the bun/app paths are
// machine-specific — presumably a dev-time convenience; confirm before deploy.
const get_siteroot_js = async function () {
  const run_result = await execFile('/home/user/bun-linux-x64/bun', ['build', 'index.js'], { cwd: '/home/user/mount/bunapp', encoding: 'utf8' })
  return run_result.stdout.trim()
}
// POST `documents` to the embedding server registered for `model_key`.
// Returns the parsed response array (assumed [{ index, embedding }, ...] —
// verify against the embedding server), or [] on unknown model_key, HTTP
// error, network failure, or an unparsable response.
const async_get_embeddings_for_model = async (model_key, documents) => {
  if (!(model_key in MODELKEY_SERVER_MAP)) {
    console.warn(`no matching model_key: ${model_key}`)
    return []
  }
  const { endpoint, prefix } = MODELKEY_SERVER_MAP[model_key]
  // Some models want a task-specific prefix prepended to each document.
  const endpoint_content = (prefix === undefined)
    ? [ ...documents ]
    : documents.map(x=>(prefix + x))
  try {
    const resp = await fetch(endpoint, {
      method: 'POST',
      // BUGFIX: declare the JSON payload explicitly
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ content: endpoint_content }),
    })
    // BUGFIX: a non-2xx response previously fell through to .json() and was
    // indistinguishable from a parse error
    if (!resp.ok) {
      console.error(`embedding endpoint ${endpoint} returned HTTP ${resp.status}`)
      return []
    }
    const resp_json = await resp.json()
    return [ ...resp_json ]
  }
  catch (err) {
    // BUGFIX: a fetch network error used to reject the caller; degrade to []
    // like every other failure mode of this function
    console.error(err)
    return []
  }
}
// Read the full request body stream and resolve with its bytes as a
// Uint8Array. Rejects on a stream 'error' event (the original promise would
// silently never settle).
const async_get_request_body = (req) => {
  return new Promise((resolve, reject) => {
    const lst_chunks = []
    req.on('data', (chunk) => {
      lst_chunks.push(chunk)
    })
    req.on('end', () => {
      // Buffer.concat is a single O(total) copy — replaces the original
      // per-byte JS array push
      resolve(new Uint8Array(Buffer.concat(lst_chunks)))
    })
    // BUGFIX: propagate stream errors instead of hanging forever
    req.on('error', reject)
  })
}
// Ensure `collection_name` exists in Milvus, creating it (and its vector
// index) on first use. Schema: id VarChar(32) primary key + FloatVector(dim).
// Failures are logged and swallowed; callers proceed regardless.
const ensure_collection_exists = async (collection_name, model_dim) => {
  const collections = await client.listCollections()
  if (collections.collection_names.indexOf(collection_name) !== -1) {
    // collection already present — nothing to do
    return
  }
  console.info(`creating collection : ${collection_name} ...`)
  const schema = [
    {
      name: 'id',
      description: 'ID field',
      data_type: DataType.VarChar,
      max_length: 32, // fits the 32-char md5 hex digest used by the insert path
      is_primary_key: true,
      autoID: false,
    },
    {
      name: 'vector',
      description: 'Vector field',
      data_type: DataType.FloatVector,
      dim: model_dim,
    },
  ]
  const result = await client.createCollection({
    collection_name,
    fields: schema,
  })
  if (result.error_code !== 'Success') {
    // BUGFIX: was console.info — a failed create must be logged as an error
    console.error('failed to create collection.')
    return
  }
  console.info(`created collection : ${collection_name} .`)
  console.info(`creating index ...`)
  const r2 = await client.createIndex({
    collection_name,
    field_name: 'vector',
    index_name: 'myindex_A',
  })
  if (r2.error_code === 'Success') {
    console.info('created index .')
  }
  else {
    console.error('failed to create index .')
  }
  return
}
const get_modelkey_from_original = (og_collectionname, model_key) => `${og_collectionname}_${model_key}`
// Embed `documents` with `model_key` and upsert the vectors into the
// model-specific collection derived from `og_collectionname`.
// Document ids are md5(original document) so clients can recompute them
// locally (intentionally NOT md5(prefix + document)).
// Always resolves { finished: true }; failures are logged, not thrown.
const async_get_documents_embeddings_then_insert_into_collection = async (model_key, model_dim, og_collectionname, documents) => {
  // make stable key for model-specific Collection
  const collection_name = get_modelkey_from_original(og_collectionname, model_key)
  await ensure_collection_exists(collection_name, model_dim)
  const result = await async_get_embeddings_for_model(model_key, documents)
  const fields_data = result.map(({ index, embedding }) => {
    const vector = embedding[0]
    const og_document = documents[index]
    // create stable key (used as primary key) for doc-specific Embedding
    // NOTE: hash the Original Document, so that Client can always have (locally-computable) id to it
    // - (i.e. intentionally not storing hash of "prefix+document")
    const md5_doc = crypto.createHash('md5').update(og_document).digest('hex')
    return { id: md5_doc, vector }
  })
  // BUGFIX: when the embedding call yielded nothing (unknown model_key or
  // embedding-server failure) the original still upserted an empty batch
  if (fields_data.length === 0) {
    console.warn(['no embeddings to upsert', { collection_name }])
    return { finished: true }
  }
  const result_upsert = await client.upsert({
    collection_name,
    fields_data
  })
  if (result_upsert.status.error_code !== 'Success') {
    console.warn(['upsert failed', { collection_name }])
  }
  return { finished: true }
}
// Collections already loaded into Milvus memory during this process lifetime.
const LOADED_COLLECTIONS = new Set()
// Load a collection on first use; subsequent calls for the same name are no-ops.
const ensure_collection_loaded = async (internal_collectionname) => {
  if (LOADED_COLLECTIONS.has(internal_collectionname)) {
    return
  }
  console.info(`loading collection : ${internal_collectionname} ...`)
  await client.loadCollection({ collection_name: internal_collectionname })
  console.info('loaded collection .')
  LOADED_COLLECTIONS.add(internal_collectionname)
  return
}
// Run `query` against every configured model's collection in parallel.
// j_params: { query, collection_name, limit }.
// Returns [{ model_key, search_result }, ...]; a model that fails to embed,
// load, or search contributes an empty search_result instead of rejecting.
const get_searchresults_from_params = async (j_params) => {
  const { query, collection_name, limit } = j_params
  const async_getresult = async (model_key) => {
    let search_result = []
    try {
      const embed_result = await async_get_embeddings_for_model(model_key, [ query ])
      // BUGFIX: an empty/failed embedding result previously threw OUTSIDE the
      // try block, rejecting the whole Promise.all and 404-ing the request
      const vector = embed_result[0]?.embedding?.[0]
      if (vector === undefined) {
        console.warn(`no query embedding for model_key: ${model_key}`)
        return { model_key, search_result }
      }
      const internal_collectionname = get_modelkey_from_original(collection_name, model_key)
      await ensure_collection_loaded(internal_collectionname)
      const full_result = await client.search({
        collection_name: internal_collectionname,
        data: vector,
        limit,
        consistency_level: 'Eventually',
        output_fields: ['id'],
      })
      search_result = full_result.results
    }
    catch (err) {
      console.error(err)
      // could not load collection, maybe does not exist
      // (or, .search failed)
    }
    return { model_key, search_result }
  }
  const lst_promises = Object.keys(MODELKEY_SERVER_MAP).map(async_getresult)
  const results = await Promise.all(lst_promises)
  return results
}
// Await `async_handlerfunc` and convert a thrown error into a flag.
// Returns [did_error, result]; result is undefined when did_error is true.
const wrap_handler_error = async (async_handlerfunc) => {
  try {
    const result = await async_handlerfunc()
    return [ false, result ]
  }
  catch (err) {
    console.error(err)
    return [ true, undefined ]
  }
}
// Route a single HTTP request.
// Returns { func_handle_res } where func_handle_res writes the response;
// it stays undefined for unmatched routes (caller then emits a 404).
// Routes:
//   GET  /           -> HTML shell
//   GET  /index.js   -> freshly bundled client JS
//   GET/POST /search -> vector search (params in ?p=<json> or JSON body)
//   POST /stash      -> embed + upsert documents across all models
const func_dohandle = async (request, res) => {
const { method, headers, body } = request
// NOTE(review): url.parse is the legacy Node API; kept as-is
const url_parts = url.parse(request.url)
const pathname = url_parts.pathname
let func_handle_res
if (method === 'GET' && pathname === '/') {
// await on a sync function is harmless; str_html is the page shell
const str_html = await get_siteroot_html()
func_handle_res = () => {
res.writeHead(200, { 'Content-Type': 'text/html' })
res.write(str_html)
res.end()
}
}
else if (method === 'GET' && pathname === '/index.js') {
// re-bundles on every request — presumably a dev convenience; confirm
const str_js = await get_siteroot_js()
func_handle_res = () => {
res.writeHead(200, { 'Content-Type': 'text/javascript' })
res.write(str_js)
res.end()
}
}
else if (pathname === '/search') {
let j_params
if (method === 'GET') {
// GET variant: JSON params packed into the `p` query-string value
const searchParams = new URLSearchParams(url_parts.query)
const str_jparams = searchParams.get('p')
j_params = JSON.parse(str_jparams)
}
else if (method === 'POST') {
// POST variant: JSON params in the request body
const body = await async_get_request_body(request)
j_params = JSON.parse(new TextDecoder().decode(body))
}
// other methods (PUT etc.) leave j_params undefined -> Not Found
if (j_params !== undefined) {
const results = await get_searchresults_from_params(j_params)
func_handle_res = () => {
res.writeHead(200, { 'Content-Type': 'application/json' })
res.write(JSON.stringify({ results }))
res.end()
}
}
// else: fallback to default Not Found
}
else if (method === 'POST' && pathname === '/stash') {
const body = await async_get_request_body(request)
const j_body = JSON.parse(new TextDecoder().decode(body))
const { collection_name, documents, wait_for_result } = j_body
// kick off one embed+upsert pipeline per configured model
const lst_promises = []
for (let [ model_key, { dim } ] of Object.entries(MODELKEY_SERVER_MAP)) {
const model_dim = dim
lst_promises.push(async_get_documents_embeddings_then_insert_into_collection(model_key, model_dim, collection_name, documents))
}
let results_out = null
if (wait_for_result === true) {
// synchronous mode: block the response until every model finished
try {
results_out = await Promise.all(lst_promises)
}
catch (err) {
console.error(err)
}
}
else {
// fire-and-forget: respond immediately, log background failures
Promise.all(lst_promises).then(()=>{}).catch(console.error)
}
// results_out is null in fire-and-forget mode (or on error)
func_handle_res = () => {
res.writeHead(200, { 'Content-Type': 'application/json' })
res.write(JSON.stringify({ results: results_out }))
res.end()
}
}
return { func_handle_res }
}
// HTTP entrypoint: delegate to func_dohandle, then fall back to an error
// status when the handler threw or matched no route.
const server = http.createServer(async (request, res) => {
  const [ did_error, handler_result ] = await wrap_handler_error(func_dohandle.bind(null, request, res))
  if (did_error) {
    // BUGFIX: a handler exception is a server error, not "not found"
    res.writeHead(500)
    res.end()
  }
  else if (handler_result.func_handle_res === undefined) {
    // no route matched
    res.writeHead(404)
    res.end()
  }
  else {
    handler_result.func_handle_res()
  }
  return
})
// Start listening and arrange a graceful shutdown on Ctrl-C.
server.listen(PORT, HOST, () => {
  console.log(`Server is running on http://${HOST}:${PORT}`)
})
// Stop accepting connections; in-flight requests are allowed to finish.
const do_shutdown = () => {
  console.log('graceful shutdown')
  server.close(() => {
    console.info('server closed.')
  })
}
process.on('SIGINT', do_shutdown)