288 lines
6.6 KiB
JavaScript
288 lines
6.6 KiB
JavaScript
import { async_run } from 'framerock'
|
|
|
|
import { mergeDeep } from './mergeDeep'
|
|
|
|
//============================================//
|
|
// CLI: bun <script> <path_templates.json> <path_providers.json>
const [ path_templates, path_providers ] = Bun.argv.slice(2)

if (path_templates === undefined || path_providers === undefined) {
    // Fail fast with a usage message rather than failing later on an undefined path.
    throw new Error('usage: bun <script> <path_templates.json> <path_providers.json>')
}

const file_templates = Bun.file(path_templates)

const file_providers = Bun.file(path_providers)

// PROMPT_TEMPLATES: template name -> template definition (looked up per request
// by handle_fetch_fallback).
// PROVIDERS: provider uid -> provider entry; wrap_api_fetch reads `url` and an
// optional `auth_token` from it.
// The two files are independent, so they are read concurrently.
const [ PROMPT_TEMPLATES, PROVIDERS ] = await Promise.all([
    file_templates.json(),
    file_providers.json(),
])
|
|
//============================================//
|
|
|
|
//============================================//
|
|
/**
 * Serialize a value for interpolation into a prompt section part.
 *
 * @param {null|string} serialization_method - `null` passes the value through
 *   unchanged; `'json_stringify'` applies `JSON.stringify`.
 * @param {*} val_in - the value to serialize.
 * @returns {*} `val_in` itself for `null`, otherwise the `JSON.stringify` result
 *   (which is `undefined` when `val_in` is `undefined`).
 * @throws {Error} for any other serialization_method.
 */
const do_serialize = function (serialization_method, val_in) {

    if (serialization_method === null) {
        return val_in
    }

    if (serialization_method === 'json_stringify') {
        return JSON.stringify(val_in)
    }

    // A real Error (rather than a bare string) gives callers and logs a stack trace.
    throw new Error(`unexpected serialization_method : ${serialization_method}`)

}
|
|
|
|
// Interval between keep-alive bytes sent while waiting on the upstream provider.
const KEEPALIVE_INTERVAL_MS = 5 * 1000

// Sentinel value produced by the keep-alive timer in `wait_with_keepalive`.
const KEEPALIVE_TICK = Symbol('keepalive_tick')


/**
 * Resolve a template's `provider_strategy` tuple into the call target.
 *
 * Only the `'strict'` strategy is supported: its options object names the
 * provider uid, endpoint path and model directly.
 *
 * @param {Array} provider_strategy - `[ strategy_type, strategy_opts ]`.
 * @returns {{ provider_uid: *, endpoint: *, model: * }}
 * @throws {Error} for any other strategy type.
 */
const resolve_provider_strategy = function (provider_strategy) {

    const [ strategy_type, strategy_opts ] = provider_strategy

    if (strategy_type !== 'strict') {
        throw new Error(`unexpected strategy_type : ${strategy_type}`)
    }

    const { provider_uid, endpoint, model } = strategy_opts

    return { provider_uid, endpoint, model }

}


/**
 * Build the JSON schema the model output is constrained to.
 *
 * @param {*} response_schema_json - the template's schema, or `undefined`.
 * @param {*} merge_response_json_schema - request-supplied overrides, or `undefined`.
 * @returns {*} `undefined` when the template declares no schema; the template
 *   schema as-is when there is nothing to merge; otherwise the deep merge of
 *   the overrides onto a copy of the template schema.
 */
const build_output_schema = function (response_schema_json, merge_response_json_schema) {

    if (response_schema_json === undefined) {
        return undefined
    }

    if (merge_response_json_schema === undefined) {
        // Still honour the template schema: otherwise `response_format` would be
        // unset and `include_output_format` parts would interpolate `undefined`.
        return response_schema_json
    }

    // NOTE(review): mergeDeep lives in ./mergeDeep and its semantics are not visible
    // here. Merge into a clone so that, if it mutates its first argument, the shared
    // PROMPT_TEMPLATES entry is not altered across requests.
    // NOTE(review): `merge_response_json_schema` comes straight from the request body;
    // confirm mergeDeep skips `__proto__` / `constructor` keys (prototype pollution).
    return mergeDeep(structuredClone(response_schema_json), merge_response_json_schema)

}


/**
 * Render a template's `context_build.sections`.
 *
 * Each entry of a section's `section_parts` is a tuple
 * `[ op_type, str_before, serialization_method, str_after ]`:
 *   - 'include_output_format' interpolates `final_out_schema`,
 *   - 'include_instruction'   interpolates `prompt_append`,
 * each serialized with `do_serialize`. The parts are then joined with the
 * section's `parts_delimiter`.
 *
 * For chat endpoints each section becomes a `{ role, content }` message.
 * For plain completions (hardcoded "template" kept for compatibility) each
 * section becomes a string, prefixed by `section_role + '\n'` when the role is
 * not null; the caller joins the sections with '\n'.
 *
 * @throws {Error} on an unknown op_type or serialization_method.
 */
const render_context_sections = function (sections, is_chat, final_out_schema, prompt_append) {

    return sections.map(({ section_role, section_parts, parts_delimiter }) => {

        const str_final = section_parts.map(([ op_type, str_before, serialization_method, str_after ]) => {

            let val

            if (op_type === 'include_output_format') {
                val = final_out_schema
            }
            else if (op_type === 'include_instruction') {
                val = prompt_append
            }
            else {
                throw new Error(`unexpected op_type : ${op_type}`)
            }

            return str_before + do_serialize(serialization_method, val) + str_after

        }).join(parts_delimiter)

        if (is_chat) {
            return { role: section_role, content: str_final }
        }

        return (section_role !== null) ? (section_role + '\n' + str_final) : str_final

    })

}


/**
 * JSON-parse the model output from a completion response body.
 * Chat endpoints carry it in `choices[0].message.content`; plain completions
 * carry it in `choices[0].text`.
 *
 * @throws if `choices[0]` is missing or the output is not valid JSON.
 */
const parse_completion_output = function (resp_json, is_chat) {

    const choice = resp_json.choices[0]

    return JSON.parse(is_chat ? choice.message.content : choice.text)

}


/**
 * Wait for `pending` to settle, yielding one space every KEEPALIVE_INTERVAL_MS
 * so the streamed response keeps sending data and the connection does not time
 * out. Leading whitespace leaves the JSON that follows valid.
 *
 * Use with `yield*`: the delegation evaluates to `pending`'s resolved value, and
 * a rejection of `pending` is thrown to the delegating generator.
 */
const wait_with_keepalive = async function* (pending) {

    for (;;) {

        const winner = await Promise.race([
            pending,
            Bun.sleep(KEEPALIVE_INTERVAL_MS).then(() => KEEPALIVE_TICK),
        ])

        if (winner !== KEEPALIVE_TICK) {
            return winner
        }

        yield ' '

    }

}


/**
 * Fallback fetch handler (handed to framerock's `async_run` at the bottom of this file).
 *
 * Handles `POST /evaluate_template` with a JSON body
 * `{ prompt_template, merge_response_json_schema, prompt_append }`:
 *   1. looks up the named template in PROMPT_TEMPLATES,
 *   2. builds a chat (`messages`) or completion (`prompt`) request from it,
 *   3. sends it to the template's provider via `wrap_api_fetch`,
 *   4. streams back keep-alive spaces, then the parsed model output as JSON.
 * Failures are returned to the client as `{"ERROR":true}`; thrown errors are
 * also logged.
 *
 * @param {Request} req
 * @returns {Response|undefined} a streaming JSON Response for the route above,
 *   `undefined` for any other method/path.
 */
const handle_fetch_fallback = function (req) {

    const url = new URL(req.url)

    if (req.method !== 'POST' || url.pathname !== '/evaluate_template') {
        // NOTE(review): presumably framerock treats `undefined` as "not handled" — confirm.
        return undefined
    }

    return new Response(
        async function* () {

            try {

                const { prompt_template, merge_response_json_schema, prompt_append } = await req.json()

                // Own-property lookup so request-supplied names such as "__proto__"
                // or "constructor" cannot resolve to inherited Object members.
                const def_template = Object.hasOwn(PROMPT_TEMPLATES, prompt_template)
                    ? PROMPT_TEMPLATES[prompt_template]
                    : undefined

                if (def_template === undefined) {
                    throw new Error(`unexpected prompt_template : ${prompt_template}`)
                }

                const { provider_strategy, context_build, response_schema_json } = def_template

                const { provider_uid, endpoint, model } = resolve_provider_strategy(provider_strategy)

                const is_chat = (endpoint === '/v1/chat/completions')

                const final_out_schema = build_output_schema(response_schema_json, merge_response_json_schema)

                const response_format = (response_schema_json === undefined)
                    ? undefined
                    : { type: 'json_schema', json_schema: { strict: true, schema: final_out_schema } }

                const final_sections = render_context_sections(
                    context_build.sections,
                    is_chat,
                    final_out_schema,
                    prompt_append,
                )

                // Chat endpoints take `messages`; plain completions take one `prompt`
                // string with sections joined by newlines. The unused key stays
                // undefined and is dropped when wrap_api_fetch JSON-serializes params.
                const final_params = {
                    model,
                    prompt: is_chat ? undefined : final_sections.join('\n'),
                    messages: is_chat ? final_sections : undefined,
                    response_format,
                }

                const pending = wrap_api_fetch(endpoint, provider_uid, final_params).then((resp_json) => (
                    (resp_json === undefined)
                        ? { ok: false }
                        : { ok: true, value: parse_completion_output(resp_json, is_chat) }
                ))

                // Keep-alive spaces flow to the client while we wait. A rejection
                // (network failure, malformed model output, ...) is thrown here and
                // reported by the catch below instead of leaving the stream open forever.
                const outcome = yield* wait_with_keepalive(pending)

                yield JSON.stringify(outcome.ok ? outcome.value : { 'ERROR': true })

            }
            catch (err) {

                console.error(err)

                yield JSON.stringify({ 'ERROR': true })

            }

        },
        { headers: { 'Content-Type': 'application/json' } },
    )

}
|
|
//============================================//
|
|
|
|
//============================================//
|
|
/**
 * POST `params` as JSON to the provider registered under `uid` in PROVIDERS.
 *
 * The request URL is the provider's `url` with `str_endpoint` appended
 * verbatim. A `Bearer` Authorization header is added when the provider entry
 * has an `auth_token`.
 *
 * @param {string} str_endpoint - endpoint path, e.g. '/v1/chat/completions'.
 * @param {string} uid - provider key in PROVIDERS.
 * @param {object} params - request body, serialized with JSON.stringify.
 * @returns {Promise<object|undefined>} the parsed JSON response body on HTTP 200;
 *   `undefined` (after logging) when the provider is unknown, the request itself
 *   fails, the status is not 200, or the body is not valid JSON.
 */
const wrap_api_fetch = async (str_endpoint, uid, params) => {

    const body = JSON.stringify(params)

    // Own-property lookup so a uid such as "constructor" cannot resolve to an
    // inherited Object member.
    const provider_info = Object.hasOwn(PROVIDERS, uid) ? PROVIDERS[uid] : undefined

    if (provider_info === undefined) {
        console.error(`no matching provider : ${uid}`)
        return undefined
    }

    const { url, auth_token } = provider_info

    const full_url = url + str_endpoint

    const headers = { 'Content-Type': 'application/json' }

    if (auth_token !== undefined) {
        headers['Authorization'] = `Bearer ${auth_token}`
    }

    let resp
    let resp_text

    try {
        resp = await fetch(full_url, { method: 'POST', headers, body })
        resp_text = await resp.text()
    }
    catch (err) {
        // Transport failure: report it through the same `undefined` channel as
        // the other failure cases instead of rejecting.
        console.error(['request failed', full_url, err])
        return undefined
    }

    if (resp.status !== 200) {
        console.error(['unexpected status', resp.status])
        // DEBUG
        console.info({ resp_text })
        return undefined
    }

    try {
        return JSON.parse(resp_text)
    }
    catch (err) {
        console.error(['invalid JSON in response', err])
        // DEBUG
        console.info({ resp_text })
        return undefined
    }

}
|
|
//============================================//
|
|
|
|
// Hand the fallback handler to framerock's async_run; a rejection from the
// returned promise is logged rather than left unhandled.
async_run({ handle_fetch_fallback }).catch((err) => console.error(err))
|