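"""OpenAI-compatible chat-completions server for MiniMind.

Loads either a native MiniMind checkpoint (optionally with LoRA weights) or a
transformers-format model, then serves /v1/chat/completions with both streaming
and non-streaming responses.
"""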
import argparse
import json
import os
import sys

__package__ = "scripts"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import time
import torch
import warnings
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.LMConfig import LMConfig
from model.model import MiniMindLM
from model.model_lora import apply_lora, load_lora

warnings.filterwarnings('ignore')

app = FastAPI()


def init_model(args):
    tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
    if args.load == 0:
        # Load a native MiniMind checkpoint produced by the training scripts.
        moe_path = '_moe' if args.use_moe else ''
        modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason'}
        ckp = f'../{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth'

        model = MiniMindLM(LMConfig(
            dim=args.dim,
            n_layers=args.n_layers,
            max_seq_len=args.max_seq_len,
            use_moe=args.use_moe
        ))

        state_dict = torch.load(ckp, map_location=device)
        model.load_state_dict({k: v for k, v in state_dict.items() if 'mask' not in k}, strict=True)

        if args.lora_name != 'None':
            # Optionally attach and load LoRA weights on top of the base checkpoint.
            apply_lora(model)
            load_lora(model, f'../{args.out_dir}/{args.lora_name}_{args.dim}.pth')
    else:
        # Load the released transformers-format model instead of a raw checkpoint.
        model = AutoModelForCausalLM.from_pretrained(
            './MiniMind2',
            trust_remote_code=True
        )
    print(f'MiniMind trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M')
    return model.eval().to(device), tokenizer


class ChatRequest(BaseModel):
    model: str
    messages: list
    temperature: float = 0.7
    top_p: float = 0.92
    max_tokens: int = 8192
    stream: bool = False


def generate_stream_response(messages, temperature, top_p, max_tokens):
    try:
        new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)[-max_tokens:]
        x = tokenizer(new_prompt).data['input_ids']
        x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
        with torch.no_grad():
            res_y = model.generate(
                x,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stream=True,
                rp=1.,
                pad_token_id=tokenizer.pad_token_id
            )
            history_idx = 0
            for y in res_y:
                answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
                # Skip empty chunks and chunks ending in an incomplete multi-byte
                # character (decoded as the replacement character '�').
                if (answer and answer[-1] == '�') or not answer:
                    continue
                delta = answer[history_idx:]
                history_idx = len(answer)
                json_data = {
                    'id': f'chatcmpl-{int(time.time())}',
                    'object': 'chat.completion.chunk',
                    'created': int(time.time()),
                    'model': 'minimind',
                    'choices': [{'index': 0, 'delta': {'content': delta}, 'finish_reason': None}]
                }
                yield f"data: {json.dumps(json_data)}\n\n"

    except Exception as e:
        yield f"data: {json.dumps({'error': str(e)})}\n\n"


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    try:
        if request.stream:
            return StreamingResponse(
                generate_stream_response(
                    messages=request.messages,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    max_tokens=request.max_tokens
                ),
                media_type="text/event-stream"
            )
        else:
            new_prompt = tokenizer.apply_chat_template(
                request.messages,
                tokenize=False,
                add_generation_prompt=True
            )[-request.max_tokens:]
            x = tokenizer(new_prompt).data['input_ids']
            x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
            with torch.no_grad():
                res_y = model.generate(
                    x,
                    eos_token_id=tokenizer.eos_token_id,
                    max_new_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    stream=False,
                    rp=1.,
                    pad_token_id=tokenizer.pad_token_id
                )
                answer = tokenizer.decode(res_y.squeeze()[x.shape[1]:].tolist(), skip_special_tokens=True)
            return {
                "id": f"chatcmpl-{int(time.time())}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "minimind",
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": answer},
                        "finish_reason": "stop"
                    }
                ]
            }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Server for MiniMind")
    parser.add_argument('--out_dir', default='out', type=str)
    parser.add_argument('--lora_name', default='None', type=str)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=8192, type=int)
    # store_true avoids the argparse type=bool pitfall where any non-empty string parses as True
    parser.add_argument('--use_moe', default=False, action='store_true')
    parser.add_argument('--load', default=0, type=int, help="0: load native torch weights, 1: load via transformers")
    parser.add_argument('--model_mode', default=1, type=int,
                        help="0: pretrain model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, tokenizer = init_model(parser.parse_args())

    uvicorn.run(app, host="0.0.0.0", port=8998)
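# Example request against the endpoint above (assumes the server is running locally
# on the default port 8998; adjust host/port if the uvicorn.run call is changed):
#
#   curl http://localhost:8998/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "minimind", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'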