// aiprox/backend/index.js
//
// 217 lines
// 4.8 KiB
// JavaScript

import { async_run } from 'framerock'
import { mergeDeep } from './mergeDeep'
//============================================//
// Startup configuration: the first two command-line arguments after the script
// are the paths of two JSON files.
//   path_templates -> PROMPT_TEMPLATES, looked up by prompt template name
//   path_providers -> PROVIDERS, looked up by provider uid
const [ path_templates, path_providers ] = Bun.argv.slice(2)
// Fail fast with a readable message instead of letting a missing path surface
// later as a file-loading error.
if (path_templates === undefined || path_providers === undefined) {
  throw new Error('usage: <script> <path_templates> <path_providers>')
}
const file_templates = Bun.file(path_templates)
const file_providers = Bun.file(path_providers)
// Both files are parsed once at module load; a missing or malformed file aborts startup.
const PROMPT_TEMPLATES = await file_templates.json()
const PROVIDERS = await file_providers.json()
//============================================//
//============================================//
// How long the response stream may stay silent before a keep-alive byte is sent.
const KEEPALIVE_INTERVAL_MS = 5 * 1000

/**
 * Translate an /evaluate_template payload into chat-completion arguments.
 *
 * @param {object} payload - parsed request body
 * @param {string} payload.prompt_template - key into PROMPT_TEMPLATES
 * @param {Array} payload.template_parts - values substituted, in order, for the `null` slots of the template
 * @param {object} payload.merge_schema - passed with the template's schema to mergeDeep to build the response schema
 * @returns {{ provider_uid, model, messages, response_format }}
 * @throws {Error} on an unknown template name or an unsupported provider/model strategy
 */
const build_chat_request = function ({ prompt_template, template_parts, merge_schema }) {
  const def_template = PROMPT_TEMPLATES[prompt_template]
  if (def_template === undefined) {
    throw new Error(`unexpected prompt_template : ${prompt_template}`)
  }
  const { provider_model_strategy, schema, template } = def_template
  const [ strategy_type, strategy_opts ] = provider_model_strategy
  // 'strict' is the only strategy implemented: provider and model come straight from the template.
  if (strategy_type !== 'strict') {
    throw new Error(`unexpected strategy_type : ${strategy_type}`)
  }
  const { provider_uid, model } = strategy_opts
  const json_schema = mergeDeep(merge_schema, schema)
  // The schema is sent twice: embedded in the prompt text and as the structured response_format.
  const schema_preamble = '<RESPONSE_JSON_SCHEMA>\n' + JSON.stringify(json_schema) + '\n</RESPONSE_JSON_SCHEMA>\n'
  const template_slots = [ schema_preamble, ...template ]
  // Copy so that consuming the parts does not mutate the caller's array.
  const remaining_parts = [...template_parts]
  // A `null` slot is a placeholder filled by the next caller-supplied part.
  // NOTE(review): if there are fewer parts than placeholders the slot renders as an empty line.
  const full_parts = template_slots.map((slot) => (slot !== null ? slot : remaining_parts.shift()))
  const messages = [
    {
      "role": "user",
      "content": full_parts.join('\n'),
    }
  ]
  const response_format = {
    "type": "json_schema",
    "json_schema": {
      "name": "chat_response",
      "strict": true,
      "schema": json_schema,
    }
  }
  return { provider_uid, model, messages, response_format }
}

/**
 * Fallback fetch handler implementing `POST /evaluate_template`.
 *
 * The response body is an async generator (as in the original code, this relies
 * on the runtime accepting a generator function as a Response body). While the
 * provider call is in flight it emits a single space every
 * KEEPALIVE_INTERVAL_MS so the connection is not timed out, then emits the
 * model's JSON result, or `{"ERROR":true}` on any failure. Leading spaces are
 * insignificant whitespace for a JSON consumer.
 *
 * @param {Request} req - incoming request
 * @returns {Response|undefined} a streaming Response for the handled route,
 *   `undefined` for every other request.
 */
const handle_fetch_fallback = function (req) {
  let resp
  const url = new URL(req.url)
  if (
    req.method === 'POST' &&
    url.pathname === '/evaluate_template'
  ) {
    resp = new Response(
      async function* () {
        try {
          const chat_request = build_chat_request(await req.json())
          const { provider_uid, model, messages, response_format } = chat_request
          // Always settles to an outcome object. Previously a rejection (network
          // failure, unparsable model output, missing `choices`) was only logged,
          // which left the wait loop below spinning forever.
          const outcome_promise = handle_chat_completion(provider_uid, model, { messages, response_format })
            .then((resp_json) => {
              if (resp_json === undefined) {
                return { did_error: true }
              }
              const result_str = resp_json.choices[0].message.content
              return { did_error: false, result_obj: JSON.parse(result_str) }
            })
            .catch((err) => {
              console.error(err)
              return { did_error: true }
            })
          // !! NOTE: have to send dummy char every so often to prevent connection timeout !!
          let outcome
          while (outcome === undefined) {
            // The timer branch resolves to undefined, so `undefined` means "still waiting".
            outcome = await Promise.race([
              outcome_promise,
              Bun.sleep(KEEPALIVE_INTERVAL_MS).then(() => undefined),
            ])
            if (outcome === undefined) {
              yield ' '
            }
          }
          if (outcome.did_error) {
            yield JSON.stringify({ 'ERROR': true })
          }
          else {
            yield JSON.stringify(outcome.result_obj)
          }
        }
        catch (err) {
          // Bad request body, unknown template, unsupported strategy, ...
          console.error(err)
          yield JSON.stringify({ 'ERROR': true })
        }
        return
      },
      { headers: { 'Content-Type': 'application/json' } },
    )
  }
  return resp
}
//============================================//
//============================================//
/**
 * POST a chat-completion request to the provider registered under `uid`.
 *
 * Failure is reported by logging and resolving to `undefined`; this covers an
 * unknown provider, a non-200 status, a transport error from `fetch`, and a
 * 200 response whose body is not valid JSON. The returned promise therefore
 * does not reject for those cases.
 *
 * @param {string} uid - key into PROVIDERS
 * @param {string} model - model identifier forwarded to the provider
 * @param {object} opts
 * @param {Array} opts.messages - chat messages forwarded as-is
 * @param {object} opts.response_format - response format forwarded as-is
 * @returns {Promise<object|undefined>} parsed response body, or `undefined` on failure
 */
const handle_chat_completion = async (uid, model, { messages, response_format }) => {
  const provider_info = PROVIDERS[uid]
  if (provider_info === undefined) {
    console.error(`no matching provider : ${uid}`)
    return undefined
  }
  const { url, auth_token } = provider_info
  const headers = { 'Content-Type': 'application/json' }
  // Providers without an auth_token are called unauthenticated.
  if (auth_token !== undefined) {
    headers['Authorization'] = `Bearer ${auth_token}`
  }
  try {
    const resp = await fetch(
      url + '/v1/chat/completions',
      {
        method: 'POST',
        headers,
        body: JSON.stringify({ model, messages, response_format }),
      }
    )
    // Read as text first so the raw body can be logged when the status is unexpected.
    const resp_text = await resp.text()
    if (resp.status !== 200) {
      console.error(['unexpected status', resp.status])
      // DEBUG
      console.info({ resp_text })
      return undefined
    }
    return JSON.parse(resp_text)
  }
  catch (err) {
    // Transport failure or a non-JSON body: same "undefined means failure" contract as above.
    console.error(err)
    return undefined
  }
}
//============================================//
// Pass the fallback fetch handler to framerock's async_run; a rejection of the
// returned promise is logged rather than left unhandled.
async_run({ handle_fetch_fallback })
  .then(() => {})
  .catch((err) => console.error(err))