Version 0.2.0 (see changelog for details)
This commit is contained in:
parent
42409c3d5a
commit
f9ef3582b5
3 changed files with 120 additions and 41 deletions
|
|
@ -19,3 +19,8 @@ Run server:
|
|||
```bash
|
||||
bun run aiprox/backend/index.js $PATH_TO_TEMPLATES $PATH_TO_PROVIDERS
|
||||
```
|
||||
|
||||
## changelog
|
||||
|
||||
- **Version 0.2.0**
|
||||
- improved templating logic
|
||||
|
|
|
|||
149
backend/index.js
149
backend/index.js
|
|
@ -13,6 +13,20 @@ const PROVIDERS = await file_providers.json()
|
|||
//============================================//
|
||||
|
||||
//============================================//
|
||||
const do_serialize = function (serialization_method, val_in) {
|
||||
let val_out
|
||||
if (serialization_method === null) {
|
||||
val_out = val_in
|
||||
}
|
||||
else if (serialization_method === 'json_stringify') {
|
||||
val_out = JSON.stringify(val_in)
|
||||
}
|
||||
else {
|
||||
throw `unexpected serialization_method : ${serialization_method}`
|
||||
}
|
||||
return val_out
|
||||
}
|
||||
|
||||
const handle_fetch_fallback = function (req) {
|
||||
|
||||
let resp
|
||||
|
|
@ -31,76 +45,138 @@ const handle_fetch_fallback = function (req) {
|
|||
|
||||
let messages
|
||||
let response_format
|
||||
|
||||
//============================================//
|
||||
const { prompt_template, template_parts, merge_schema } = await req.json()
|
||||
const { prompt_template, merge_response_json_schema, prompt_append } = await req.json()
|
||||
|
||||
const def_template = PROMPT_TEMPLATES[prompt_template]
|
||||
|
||||
if (def_template !== undefined) {
|
||||
|
||||
let provider_uid
|
||||
let endpoint
|
||||
let model
|
||||
let prompt
|
||||
let messages
|
||||
let response_format
|
||||
|
||||
//============================================//
|
||||
const { provider_model_strategy, schema, template } = def_template
|
||||
const { provider_strategy, context_build, response_schema_json } = def_template
|
||||
|
||||
const [ strategy_type, strategy_opts ] = provider_model_strategy
|
||||
const [ strategy_type, strategy_opts ] = provider_strategy
|
||||
|
||||
if (strategy_type === 'strict') {
|
||||
provider_uid = strategy_opts.provider_uid
|
||||
endpoint = strategy_opts.endpoint
|
||||
model = strategy_opts.model
|
||||
}
|
||||
else {
|
||||
throw `unexpected strategy_type : ${strategy_type}`
|
||||
}
|
||||
//============================================//
|
||||
|
||||
const json_schema = mergeDeep(merge_schema, schema)
|
||||
const template_slots = [ '<RESPONSE_JSON_SCHEMA>\n' + JSON.stringify(json_schema) + '\n</RESPONSE_JSON_SCHEMA>\n', ...template ]
|
||||
const is_chat = (endpoint === '/v1/chat/completions')
|
||||
|
||||
const copy_parts = [...template_parts]
|
||||
//============================================//
|
||||
let final_out_schema
|
||||
|
||||
const full_parts = []
|
||||
for (let x of template_slots) {
|
||||
if (x !== null) {
|
||||
full_parts.push(x)
|
||||
}
|
||||
else {
|
||||
full_parts.push(copy_parts.shift())
|
||||
}
|
||||
}
|
||||
if (
|
||||
response_schema_json !== undefined &&
|
||||
merge_response_json_schema !== undefined
|
||||
) {
|
||||
|
||||
const content = full_parts.join('\n')
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
}
|
||||
]
|
||||
final_out_schema = mergeDeep(response_schema_json, merge_response_json_schema)
|
||||
|
||||
response_format = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "chat_response",
|
||||
"strict": true,
|
||||
"schema": json_schema,
|
||||
type: 'json_schema',
|
||||
json_schema: {
|
||||
strict: true,
|
||||
schema: final_out_schema,
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
const { sections } = context_build
|
||||
|
||||
// for /chat/completions, maps to Messages
|
||||
// NOTE: hardcoded "template" for /completions for compat
|
||||
// - i.e. newline delimiters, after "role" and between "sections"
|
||||
|
||||
const final_sections = []
|
||||
|
||||
for (let section of sections) {
|
||||
|
||||
const { section_role, section_parts, parts_delimiter } = section
|
||||
|
||||
const parts_complete = []
|
||||
|
||||
for (let [ op_type, str_before, serialization_method, str_after ] of section_parts) {
|
||||
if (op_type === 'include_output_format') {
|
||||
parts_complete.push(str_before + do_serialize(serialization_method, final_out_schema) + str_after)
|
||||
}
|
||||
else if (op_type === 'include_instruction') {
|
||||
parts_complete.push(str_before + do_serialize(serialization_method, prompt_append) + str_after)
|
||||
}
|
||||
else {
|
||||
throw `unexpected op_type : ${op_type}`
|
||||
}
|
||||
}
|
||||
|
||||
const str_final = parts_complete.join(parts_delimiter)
|
||||
|
||||
if (is_chat === false) {
|
||||
if (section_role !== null) {
|
||||
final_sections.push(section_role + '\n' + str_final)
|
||||
}
|
||||
else {
|
||||
final_sections.push(str_final)
|
||||
}
|
||||
}
|
||||
else {
|
||||
final_sections.push({ role: section_role, content: str_final })
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (is_chat) {
|
||||
messages = [...final_sections]
|
||||
}
|
||||
else {
|
||||
prompt = final_sections.join('\n')
|
||||
}
|
||||
|
||||
//============================================//
|
||||
|
||||
//============================================//
|
||||
// !! NOTE: have to send dummy char every so often to prevent connection timeout !!
|
||||
// NOTE: have to send dummy char every so often to prevent connection timeout
|
||||
|
||||
let result_obj
|
||||
let did_error
|
||||
|
||||
handle_chat_completion(provider_uid, model, { messages, response_format }).then((resp_json) => {
|
||||
const final_params = {
|
||||
model,
|
||||
prompt,
|
||||
messages,
|
||||
response_format
|
||||
}
|
||||
|
||||
wrap_api_fetch(
|
||||
endpoint,
|
||||
provider_uid,
|
||||
final_params,
|
||||
).then((resp_json) => {
|
||||
|
||||
if (resp_json === undefined) {
|
||||
did_error = true
|
||||
}
|
||||
else {
|
||||
const result_str = resp_json.choices[0].message.content
|
||||
result_obj = JSON.parse(result_str)
|
||||
if (is_chat === false) {
|
||||
result_obj = JSON.parse(resp_json.choices[0].text)
|
||||
}
|
||||
else {
|
||||
result_obj = JSON.parse(resp_json.choices[0].message.content)
|
||||
}
|
||||
}
|
||||
return
|
||||
}).catch(console.error)
|
||||
|
|
@ -154,15 +230,10 @@ const handle_fetch_fallback = function (req) {
|
|||
//============================================//
|
||||
|
||||
//============================================//
|
||||
const handle_chat_completion = async (uid, model, { messages, response_format }) => {
|
||||
const wrap_api_fetch = async (str_endpoint, uid, params) => {
|
||||
|
||||
let submsg_resp
|
||||
|
||||
const params = {
|
||||
model,
|
||||
messages,
|
||||
response_format,
|
||||
}
|
||||
const body = JSON.stringify(params)
|
||||
|
||||
const provider_info = PROVIDERS[uid]
|
||||
|
|
@ -176,7 +247,7 @@ const handle_chat_completion = async (uid, model, { messages, response_format })
|
|||
|
||||
const { url, auth_token } = provider_info
|
||||
|
||||
const full_url = url + '/v1/chat/completions'
|
||||
const full_url = url + str_endpoint
|
||||
|
||||
let headers
|
||||
if (auth_token !== undefined) {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
{
|
||||
"name": "aiprox",
|
||||
"version": "0.1.0",
|
||||
"version": "0.2.0",
|
||||
"type": "module",
|
||||
"main": "backend/index.js"
|
||||
"main": "backend/index.js",
|
||||
"dependencies": {
|
||||
"framerock": "git+https://git.daemons.my/dab/framerock.git"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue