- """
- Inference runner for dual Qwen models.
- """
- import json
- import logging
- import sqlite3
- from pathlib import Path
- from typing import Dict, List, cast
- import pandas as pd
- import requests
- from json_repair import loads  # third-party package: pip install json-repair
- from tqdm import tqdm
- from pipeline.common_defs import Chunk, Message
- class InferenceRunner:
- """Run inference on dual Qwen models"""
- def __init__(
- self,
- batch_name: str,
- qwen3_url: str = "http://localhost:8000",
- qwen25_url: str = "http://localhost:8001",
- output_dir: str = "./pipeline_output",
- ):
- self.batch_name = batch_name
- self.qwen3_url = qwen3_url
- self.qwen25_url = qwen25_url
- self.output_dir = Path(output_dir)
- self.logger = logging.getLogger("InferenceRunner")
- self.logger.setLevel(logging.INFO)
- if not self.logger.handlers:
- handler = logging.StreamHandler()
- formatter = logging.Formatter(
- "%(levelname)s - [%(filename)s:%(lineno)d] %(message)s"
- )
- handler.setFormatter(formatter)
- self.logger.addHandler(handler)
- self.db = sqlite3.connect(
- self.output_dir / f"{batch_name}_processed.db3",
- check_same_thread=False,
- )
- # Use sqlite3.Row so fetched rows can be accessed by column name; this must
- # be set before the cursor is created for the cursor to pick it up.
- self.db.row_factory = sqlite3.Row
- self.cursor = self.db.cursor()
- sql = """
- CREATE TABLE IF NOT EXISTS processed (
- chunk_id INTEGER,
- model_name TEXT,
- message_index INTEGER,
- timestamp DATETIME,
- sender TEXT,
- message TEXT,
- responsive BOOLEAN,
- reason TEXT,
- criteria TEXT,
- confidence TEXT,
- PRIMARY KEY (model_name, message_index)
- );
- """
- self.cursor.execute(sql)
- self.logger.info("summary database initialized")
- def _create_user_prompt(self, chunk: Chunk) -> str:
- """Create inference request for a chunk"""
- # Format messages
- messages_text = ""
- for msg in chunk.messages:
- messages_text += (
- f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message}\n"
- )
- # Create full prompt
- prompt = f"""
- Review and classify the following messages.
- MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}):
- {messages_text}
- Provide your response as valid JSON following the specified format.
- """
- return prompt
- def _create_chunks(self) -> list[Chunk]:
- with open("pipeline_output/chunks.json", "r") as f:
- chunk_data = json.load(f)
- msg_df = pd.read_csv(self.output_dir / "preprocessed_messages.csv")
- # Reconstruct chunks (simplified)
- chunks = []
- for item in chunk_data["filtered_chunks"][:10]: # First 10 for testing
- chunk = Chunk(
- chunk_id=item["chunk_id"],
- start_line=item["start_line"],
- end_line=item["end_line"],
- messages=[],
- combined_text="",
- timestamp_start=item["timestamp_start"],
- timestamp_end=item["timestamp_end"],
- )
- chunk_messages = []
- # start_line/end_line are 1-based and inclusive; iloc's stop is exclusive,
- # so only the slice start needs the -1 adjustment.
- df_range = msg_df.iloc[item["start_line"] - 1 : item["end_line"]]
- for index, row in df_range.iterrows():
- message = Message(
- (index + 1),
- row["timestamp"],
- row["sender"],
- row["message_normalized"],
- )
- chunk_messages.append(message)
- chunk.messages = chunk_messages
- chunks.append(chunk)
- return chunks
- def run_inference(self, temperature: float = 0.1, max_tokens: int = 2048):
- """Run inference on both models"""
- self.logger.info("=" * 80)
- self.logger.info("RUNNING DUAL QWEN INFERENCE")
- self.logger.info("=" * 80)
- chunks = self._create_chunks()
- self.logger.info("\nRunning Qwen 3 235B inference...")
- self._run_model_inference(
- chunks, self.qwen3_url, "Qwen3-235B", temperature, max_tokens
- )
- self.logger.info("\nRunning Qwen 2.5 72B inference...")
- self._run_model_inference(
- chunks, self.qwen25_url, "Qwen2.5-72B", temperature, max_tokens
- )
- self.logger.info("\n" + "=" * 80)
- self.logger.info("INFERENCE COMPLETE")
- self.logger.info("=" * 80)
- def _create_system_prompt(self) -> str:
- """Create system prompt for LLM"""
- prompt = ""
- with Path(self.output_dir, "system_prompt.txt").open("r") as file:
- prompt = file.read()
- return prompt
- def _create_response_format(self) -> str:
- """Create response format for LLM"""
- response_format = ""
- with Path(self.output_dir, "response_format.json").open() as file:
- response_format = file.read()
- return response_format
- def _check_existing_result(self, chunk: Chunk, model_name: str) -> bool:
- """Check if result already exists in db"""
- sql = """
- SELECT
- COUNT(*) AS num_messages
- FROM
- processed
- WHERE
- model_name = ?
- AND chunk_id = ?
- AND responsive IS NOT NULL
- """
- self.cursor.execute(sql, (model_name, chunk.chunk_id))
- row = self.cursor.fetchone()
- if row and row["num_messages"] == len(chunk.messages):
- return True
- return False
- def _save_result(self, chunk: Chunk, results: list[dict], model_name: str):
- """Save result to db"""
- # merge the chunk messages with the results
- merged_results = {}
- for msg in chunk.messages:
- merged_results[msg.line_number] = {"message": msg}
- for item in results:
- if item["message_index"] in merged_results:
- merged_results[item["message_index"]].update(item)
- else:
- merged_results[item["message_index"]] = item.copy()
- sql = """
- INSERT INTO processed (
- chunk_id,
- model_name,
- message_index,
- timestamp,
- sender,
- message,
- responsive,
- reason,
- criteria,
- confidence
- ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ? )
- ON CONFLICT (model_name, message_index) DO UPDATE SET
- chunk_id = excluded.chunk_id,
- timestamp = excluded.timestamp,
- sender = excluded.sender,
- message = excluded.message,
- responsive = excluded.responsive,
- reason = excluded.reason,
- criteria = excluded.criteria,
- confidence = excluded.confidence
- """
- for result in merged_results.values():
- msg = result.get("message", None)
- if msg is None or not isinstance(msg, Message):
- self.logger.error(
- f"result without a matching source message:\n{result}"
- )
- continue
- # "criteria" arrives as a list from the model; serialize it so sqlite
- # can store it in a TEXT column.
- criteria = result.get("criteria", None)
- if isinstance(criteria, list):
- criteria = json.dumps(criteria)
- self.cursor.execute(
- sql,
- (
- result.get("chunk_id", None),
- model_name,
- msg.line_number,
- msg.timestamp,
- msg.sender,
- msg.message,
- result.get("responsive", None),
- result.get("reason", None),
- criteria,
- result.get("confidence", None),
- ),
- )
- # Commit once per chunk so results survive an interrupted run.
- self.db.commit()
- def _run_model_inference(
- self,
- chunks: List[Chunk],
- model_url: str,
- model_name: str,
- temperature: float,
- max_tokens: int,
- ):
- """Run inference on a single model"""
- system_prompt = self._create_system_prompt()
- response_format = self._create_response_format()
- success = 0
- errors = 0
- for chunk in tqdm(chunks, desc=f"{model_name} inference"):
- # check if this chunk has already been processed
- if self._check_existing_result(chunk, model_name):
- continue
- prompt_messages = []
- prompt_messages.append({"role": "system", "content": system_prompt})
- prompt_messages.append(
- {"role": "user", "content": self._create_user_prompt(chunk)}
- )
- payload = {
- "model": model_name,
- "messages": prompt_messages,
- "temperature": temperature,
- "max_tokens": max_tokens,
- "response_format": {
- "type": "json_schema",
- "json_schema": {
- "name": "structured_response",
- "schema": json.loads(response_format),
- },
- },
- }
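- # The "response_format" field above requests schema-constrained JSON from an
- # OpenAI-compatible server (vLLM and similar support json_schema); servers
- # without structured-output support may reject or ignore it.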
- # "top_p",
- # "top_k",
- # "frequency_penalty",
- # "presence_penalty",
- # # "stop",
- # # "skip_special_tokens",
- # "enable_thinking",
- headers = {"Content-Type": "application/json"}
- response = "Not Processed"
- try:
- response = requests.post(
- f"{model_url}/v1/completions", headers=headers, json=payload
- )
- response.raise_for_status()
- # logger.log(LEVEL_TRACE, f"Response {response.status_code}\n{response.text}")
- data = response.json()
- if "error" in data:
- raise RuntimeError("LLM error")
- choices = data.get("choices", [])
- if not choices:
- raise KeyError("No choices found in response")
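- # Servers may return either a chat-completion shape ("message" with
- # "content") or a legacy completion shape ("text"); handle both.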
- first_choice = choices[0]
- if "message" in first_choice and first_choice["message"]:
- response_text = first_choice["message"].get("content", "")
- else:
- response_text = first_choice.get("text", "")
- if not response_text:
- raise ValueError("No response found")
- result = self._parse_response(response_text, chunk, model_name)
- if result:
- # Persist the parsed classifications before counting the chunk as done.
- self._save_result(chunk, result, model_name)
- success += 1
- else:
- raise RuntimeError("Could not parse result")
- except Exception as e:
- self.logger.error(
- f"Error processing chunk {chunk.chunk_id}: {e}\nResponse was:\n{response}"
- )
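- # Save the chunk with empty results so its messages are recorded with null
- # classifications; the chunk is retried on the next run because
- # _check_existing_result requires non-null classifications.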
- self._save_result(chunk, [], model_name)
- errors += 1
- return success, errors
- def _parse_response(
- self, response_text: str, chunk: Chunk, model_name: str
- ) -> list[dict]:
- """Parse model response"""
- parsed_list = []
- try:
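- # json_repair.loads tolerates the minor JSON malformations (trailing commas,
- # truncated output) that models occasionally emit, unlike json.loads.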
- parsed = loads(response_text)
- parsed_list = cast(List[Dict], parsed)
- except Exception as e:
- self.logger.error(f"Error parsing response for chunk {chunk.chunk_id}: {e}")
- if not parsed_list:
- return []
- responses = []
- for result in parsed_list:
- try:
- responses.append(
- {
- "chunk_id": chunk.chunk_id,
- "message_index": result.get("message_index", None),
- "responsive": result.get("responsive", None),
- "reason": result.get("reason", ""),
- "criteria": result.get("criteria", []),
- "confidence": result.get("confidence", "low"),
- }
- )
- except Exception as e:
- self.logger.error(
- f"Error parsing response line: {e}\n{result}"
- )
- return responses
- if __name__ == "__main__":
- import argparse
- parser = argparse.ArgumentParser(description="Run dual Qwen inference")
- parser.add_argument("batch_name")
- parser.add_argument("--qwen3-url", default="http://localhost:8001")
- parser.add_argument("--qwen25-url", default="http://localhost:8002")
- parser.add_argument("--output-dir", default="./pipeline_output")
- args = parser.parse_args()
- runner = InferenceRunner(
- args.batch_name, args.qwen3_url, args.qwen25_url, args.output_dir
- )
- runner.run_inference()