diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py index 0af680bb5..3614272ef 100644 --- a/fastchat/serve/vllm_worker.py +++ b/fastchat/serve/vllm_worker.py @@ -54,6 +54,7 @@ def __init__( logger.info( f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..." ) + self.llm_engine = llm_engine self.tokenizer = llm_engine.engine.tokenizer # This is to support vllm >= 0.2.7 where TokenizerGroup was introduced # and llm_engine.engine.tokenizer was no longer a raw tokenizer @@ -116,7 +117,7 @@ async def generate_stream(self, params): frequency_penalty=frequency_penalty, best_of=best_of, ) - results_generator = engine.generate(context, sampling_params, request_id) + results_generator = self.llm_engine.generate(context, sampling_params, request_id) async for request_output in results_generator: prompt = request_output.prompt @@ -135,7 +136,7 @@ async def generate_stream(self, params): aborted = False if request and await request.is_disconnected(): - await engine.abort(request_id) + await self.llm_engine.abort(request_id) request_output.finished = True aborted = True for output in request_output.outputs: