""" Inference runner for dual Qwen models. """ from collections import defaultdict import json import pandas as pd import requests from typing import List, Dict, TypedDict, cast from pathlib import Path import logging from tqdm import tqdm import sqlite3 from json_repair import repair_json, loads from pipeline.common_defs import Chunk, Message class InferenceRunner: """Run inference on dual Qwen models""" def __init__( self, batch_name: str, qwen3_url: str = "http://localhost:8000", qwen25_url: str = "http://localhost:8001", output_dir: str = "./pipeline_output", ): self.batch_name = batch_name self.qwen3_url = qwen3_url self.qwen25_url = qwen25_url self.output_dir = Path(output_dir) self.logger = logging.getLogger("InferenceRunner") self.logger.setLevel(logging.INFO) if not self.logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter( "%(levelname)s - [%(filename)s:%(lineno)d] %(message)s" ) handler.setFormatter(formatter) self.logger.addHandler(handler) self.db = sqlite3.connect( self.output_dir / f"{batch_name}_processed.db3", check_same_thread=False, ) self.cursor = self.db.cursor() sql = """ CREATE TABLE IF NOT EXISTS processed ( chunk_id INTEGER, model_name: TEXT, message_index INTEGER, timestamp DATETIME, sender: TEXT, message TEXT, responsive BOOLEAN, reason TEXT, criteria TEXT, confidence TEXT, PRIMARY KEY (model_name, message_index) ); """ self.cursor.execute(sql) self.db.row_factory = sqlite3.Row self.logger.info("summary database initialized") def _create_user_prompt(self, chunk: Chunk) -> str: """Create inference request for a chunk""" # Format messages messages_text = "" for msg in chunk.messages: messages_text += ( f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message}\n" ) # Create full prompt prompt = f""" Review and classify the following messages. MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}): {messages_text} Provide your response as valid JSON following the specified format. 
""" return prompt def _create_chunks(self) -> list[Chunk]: with open("pipeline_output/chunks.json", "r") as f: chunk_data = json.load(f) msg_df = pd.read_csv(self.output_dir / "preprocessed_messages.csv") # Reconstruct chunks (simplified) chunks = [] for item in chunk_data["filtered_chunks"][:10]: # First 10 for testing chunk = Chunk( chunk_id=item["chunk_id"], start_line=item["start_line"], end_line=item["end_line"], messages=[], combined_text="", timestamp_start=item["timestamp_start"], timestamp_end=item["timetamp_end"], ) chunk_messages = [] dfRange = msg_df.iloc[item["start_line"] - 1 : item["end_line"] - 1] for index, row in dfRange.itertuples(): message = Message( (index + 1), row["timestamp"], row["sender"], row["message_normalized"], ) chunks.append(chunk) return chunks def run_inference(self, temperature: float = 0.1, max_tokens: int = 2048): """Run inference on both models""" self.logger.info("=" * 80) self.logger.info("RUNNING DUAL QWEN INFERENCE") self.logger.info("=" * 80) chunks = self._create_chunks() self.logger.info("\nRunning Qwen 3 235B inference...") self._run_model_inference( chunks, self.qwen3_url, "Qwen3-235B", temperature, max_tokens ) self.logger.info("\nRunning Qwen 2.5 72B inference...") self._run_model_inference( chunks, self.qwen25_url, "Qwen2.5-72B", temperature, max_tokens ) self.logger.info("\n" + "=" * 80) self.logger.info("INFERENCE COMPLETE") self.logger.info("=" * 80) def _create_system_prompt(self) -> str: """Create system prompt for LLM""" prompt = "" with Path(self.output_dir, "system_prompt.txt").open("r") as file: prompt = file.read() return prompt def _create_response_format(self) -> str: """Create response format for LLM""" response_format = "" with Path(self.output_dir, "response_format.json").open() as file: response_format = file.read() return response_format def _check_existing_result(self, chunk: Chunk, model_name) -> bool: """Check if result already exists in db""" sql = """ SELECT COUNT(*) AS num_messages FROM processed WHERE model_name = ? AND chunk_id = ? AND responsive IS NOT NULL """ result = self.cursor.execute(sql, (model_name, chunk.chunk_id)) row: dict = self.cursor.fetchone() if row and row["0"] == len(chunk.messages): return True return False def _save_result(self, chunk: Chunk, results: list[dict], model_name: str): """Save result to db""" # merge the chunk messages with the results merged_results = {} for msg in chunk.messages: merged_results[msg.line_number] = {"message": msg} for item in results: if item["message_index"] in merged_results: merged_results[item["message_index"]].update(item) else: merged_results[item["message_index"]] = item.copy() sql = """ INSERT INTO processed ( chunk_id, model_name, message_index, timestamp, sender, message, responsive, reason, criteria, confidence ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ? 

    def _save_result(self, chunk: Chunk, results: list[dict], model_name: str):
        """Save results to the db, one row per chunk message."""
        # Merge the chunk messages with the model results, keyed by line number
        merged_results: dict[int, dict] = {}
        for msg in chunk.messages:
            merged_results[msg.line_number] = {"message": msg}
        for item in results:
            if item["message_index"] in merged_results:
                merged_results[item["message_index"]].update(item)
            else:
                merged_results[item["message_index"]] = item.copy()

        sql = """
            INSERT INTO processed (
                chunk_id, model_name, message_index, timestamp, sender,
                message, responsive, reason, criteria, confidence
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT (model_name, message_index) DO UPDATE SET
                chunk_id = excluded.chunk_id,
                timestamp = excluded.timestamp,
                sender = excluded.sender,
                message = excluded.message,
                responsive = excluded.responsive,
                reason = excluded.reason,
                criteria = excluded.criteria,
                confidence = excluded.confidence
        """
        for result in merged_results.values():
            msg = result.get("message", None)
            if msg is None or not isinstance(msg, Message):
                self.logger.error(
                    f"somehow we have a result without a message: \n{result}"
                )
                continue
            criteria = result.get("criteria", None)
            if isinstance(criteria, list):
                # sqlite cannot bind a Python list; store criteria as JSON text
                criteria = json.dumps(criteria)
            self.cursor.execute(
                sql,
                (
                    result.get("chunk_id", None),
                    model_name,
                    msg.line_number,
                    msg.timestamp,
                    msg.sender,
                    msg.message,
                    result.get("responsive", None),
                    result.get("reason", None),
                    criteria,
                    result.get("confidence", None),
                ),
            )
        self.db.commit()

    def _run_model_inference(
        self,
        chunks: List[Chunk],
        model_url: str,
        model_name: str,
        temperature: float,
        max_tokens: int,
    ):
        """Run inference on a single model."""
        system_prompt = self._create_system_prompt()
        response_format = self._create_response_format()

        success = 0
        errors = 0
        for chunk in tqdm(chunks, desc=f"{model_name} inference"):
            # Skip chunks that have already been processed
            if self._check_existing_result(chunk, model_name):
                continue

            prompt_messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self._create_user_prompt(chunk)},
            ]
            payload = {
                "model": model_name,
                "messages": prompt_messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "structured_response",
                        "schema": json.loads(response_format),
                    },
                },
                # Other sampling parameters that could be exposed here:
                # top_p, top_k, frequency_penalty, presence_penalty, stop,
                # skip_special_tokens, enable_thinking.
            }
            headers = {"Content-Type": "application/json"}

            response = "Not Processed"
            try:
                # The payload uses the chat "messages" format, so post to the
                # chat completions endpoint rather than /v1/completions.
                response = requests.post(
                    f"{model_url}/v1/chat/completions", headers=headers, json=payload
                )
                response.raise_for_status()
                data = response.json()
                if "error" in data:
                    raise RuntimeError("LLM error")
                choices = data.get("choices", [])
                if not choices:
                    raise KeyError("No choices found in response")
                first_choice = choices[0]
                if "message" in first_choice and first_choice["message"]:
                    response_text = first_choice["message"].get("content", "")
                else:
                    response_text = first_choice.get("text", "")
                if not response_text:
                    raise ValueError("No response found")

                result = self._parse_response(response_text, chunk, model_name)
                if result:
                    self._save_result(chunk, result, model_name)
                    success += 1
                else:
                    raise RuntimeError("Could not parse result")
            except Exception as e:
                self.logger.error(
                    f"Error processing chunk {chunk.chunk_id}: \nResponse was:\n{response}\n{e}"
                )
                # Record the chunk's messages with empty results so rows exist
                self._save_result(chunk, [], model_name)
                errors += 1
        return success, errors

    def _parse_response(
        self, response_text: str, chunk: Chunk, model_name: str
    ) -> list[dict]:
        """Parse a model response into a list of per-message result dicts."""
        parsed_list: list[dict] = []
        try:
            # json_repair's loads() tolerates (and repairs) malformed JSON
            parsed = loads(response_text)
            parsed_list = cast(List[Dict], parsed)
        except Exception:
            self.logger.error(f"Error parsing response for chunk {chunk.chunk_id}")

        if not parsed_list:
            return []

        responses = []
        for result in parsed_list:
            try:
                responses.append(
                    {
                        "chunk_id": chunk.chunk_id,
                        "message_index": result.get("message_index", None),
                        "responsive": result.get("responsive", None),
                        "reason": result.get("reason", ""),
                        "criteria": result.get("criteria", []),
                        "confidence": result.get("confidence", "low"),
                    }
                )
            except Exception as e:
                self.logger.error(f"Error parsing response line: \n{e}\n{result}")
        return responses
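
# The response_format.json file read by _create_response_format() is not
# defined in this module. Based on the fields _parse_response() reads, a
# compatible JSON schema would look roughly like this (an assumption about the
# file's contents, not a copy of it):
#
#     {
#         "type": "array",
#         "items": {
#             "type": "object",
#             "properties": {
#                 "message_index": {"type": "integer"},
#                 "responsive": {"type": "boolean"},
#                 "reason": {"type": "string"},
#                 "criteria": {"type": "array", "items": {"type": "string"}},
#                 "confidence": {"type": "string", "enum": ["low", "medium", "high"]}
#             },
#             "required": ["message_index", "responsive"]
#         }
#     }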

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run dual Qwen inference")
    parser.add_argument("batch_name")
    parser.add_argument("--qwen3-url", default="http://localhost:8001")
    parser.add_argument("--qwen25-url", default="http://localhost:8002")
    parser.add_argument("--output-dir", default="./pipeline_output")

    args = parser.parse_args()
    runner = InferenceRunner(
        args.batch_name, args.qwen3_url, args.qwen25_url, args.output_dir
    )
    runner.run_inference()
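
# Example invocation (module filename and batch name are assumptions; the two
# OpenAI-compatible model servers must already be running at these URLs):
#
#     python inference_runner.py my_batch \
#         --qwen3-url http://localhost:8001 \
#         --qwen25-url http://localhost:8002 \
#         --output-dir ./pipeline_output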