""" Improved Inference runner for dual Qwen models with parallel processing. """ import json import os import traceback import pandas as pd import requests from typing import List, Dict, cast from pathlib import Path import logging from tqdm import tqdm import sqlite3 from json_repair import loads import concurrent.futures import threading from pipeline.common_defs import Chunk, Message class InferenceRunner: """Run inference on dual Qwen models with parallel processing""" def __init__( self, batch_name: str, qwen3_model_name: str = "", qwen3_url: str = "http://localhost:8001", qwen25_model_name: str = "", qwen25_url: str = "http://localhost:8002", output_dir: str = "./pipeline_output", max_workers: int = 4, # New parameter for parallel processing batch_size: int = 4, temperature: float = 0.01, top_p: float = 0.95, max_tokens: int = 8192, ): self.batch_name = batch_name self.qwen3_model_name = qwen3_model_name self.qwen3_url = qwen3_url self.qwen25_model_name = qwen25_model_name self.qwen25_url = qwen25_url self.output_dir = Path(output_dir) self.max_workers = max_workers self.batch_size = batch_size self.max_tokens = max_tokens self.temperature = temperature self.top_p = top_p self.logger = logging.getLogger("InferenceRunner") self.logger.setLevel(logging.DEBUG) if not self.logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter( "%(levelname)s - [%(filename)s:%(lineno)d] %(message)s" ) handler.setFormatter(formatter) self.logger.addHandler(handler) # Create output directory if it doesn't exist self.output_dir.mkdir(parents=True, exist_ok=True) self.db = sqlite3.connect( self.output_dir / f"{batch_name}_processed.db3", check_same_thread=False, ) self.db.row_factory = sqlite3.Row self.cursor = self.db.cursor() # Fixed SQL syntax error (removed colon after TEXT) sql = """ CREATE TABLE IF NOT EXISTS processed ( chunk_id INTEGER, model_name TEXT, message_index INTEGER, timestamp DATETIME, sender TEXT, message TEXT, responsive BOOLEAN, reason TEXT, criteria TEXT, confidence TEXT, PRIMARY KEY (model_name, message_index) ); """ self.cursor.execute(sql) self.db.commit() # Added commit self.logger.info("Summary database initialized") # Thread lock for database operations self.db_lock = threading.Lock() def _create_user_prompt(self, chunk: Chunk) -> str: """Create inference request for a chunk""" # Format messages messages_text = "" for msg in chunk.messages: messages_text += f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message_normalized}\n" # Create full prompt prompt = f""" Review and classify the following messages. MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}): {messages_text} Provide your response as valid JSON following the specified format. 
""" return prompt def _create_chunks(self) -> list[Chunk]: """Create chunks from preprocessed data""" chunks_file = self.output_dir / "chunks.json" with open(chunks_file, "r") as f: chunk_data = json.load(f) msg_file = self.output_dir / "preprocessed_messages.csv" msg_df = pd.read_csv(msg_file) # Reconstruct chunks chunks = [] for item in chunk_data[:2]: # First 10 for testing chunk = Chunk( chunk_id=item["chunk_id"], start_line=item["start_line"], end_line=item["end_line"], messages=[], combined_text="", timestamp_start=item["timestamp_start"], timestamp_end=item.get("timestamp_end", ""), ) chunk_messages = [] dfRange = msg_df.iloc[item["start_line"] - 1 : item["end_line"]] # index,timestamp,sender,message,line_number,message_normalized for row in dfRange.itertuples(): message = Message( line_number=row[4], timestamp=row[1], sender=row[2], message=row[3], message_normalized=row[5], ) chunk_messages.append(message) chunk.messages = chunk_messages chunks.append(chunk) return chunks def run_inference(self): """Run inference on both models with parallel processing""" self.logger.info("=" * 80) self.logger.info("RUNNING DUAL QWEN INFERENCE WITH PARALLEL PROCESSING") self.logger.info("=" * 80) chunks = self._create_chunks() if not chunks: self.logger.error("No chunks found. Exiting.") return # Run both models in parallel with concurrent.futures.ThreadPoolExecutor( max_workers=self.max_workers ) as executor: # Submit both model inference tasks qwen3_future = executor.submit( self._run_model_inference, chunks, self.qwen3_url, self.qwen3_model_name, ) # qwen25_future = executor.submit( # self._run_model_inference, # chunks, # self.qwen25_url, # self.qwen25_model_name, # ) # Wait for both to complete qwen3_success, qwen3_errors = qwen3_future.result() # qwen25_success, qwen25_errors = qwen25_future.result() self.logger.info( f"\nQwen3-235B Results: {qwen3_success} success, {qwen3_errors} errors" ) # self.logger.info( # f"Qwen2.5-72B Results: {qwen25_success} success, {qwen25_errors} errors" # ) self.logger.info("\n" + "=" * 80) self.logger.info("INFERENCE COMPLETE") self.logger.info("=" * 80) def _create_system_prompt(self) -> str: """Create system prompt for LLM""" prompt_file = Path(self.output_dir, "system_prompt.txt") with prompt_file.open("r") as file: prompt = file.read() return prompt def _create_response_format(self) -> str: """Create response format for LLM""" format_file = Path(self.output_dir, "response_format.json") with format_file.open("r") as file: response_format = file.read() return response_format def _check_existing_result(self, chunk: Chunk, model_name: str) -> bool: """Check if result already exists in db""" with self.db_lock: sql = """ SELECT COUNT(*) AS num_messages FROM processed WHERE model_name = ? AND chunk_id = ? 
                  AND responsive IS NOT NULL
            """
            self.cursor.execute(sql, (model_name, chunk.chunk_id))
            row = self.cursor.fetchone()
            if row and row["num_messages"] == len(chunk.messages):
                return True
        return False

    def _save_result(self, chunk: Chunk, results: list[dict], model_name: str):
        """Save result to db"""
        # Merge the chunk messages with the results
        merged_results = {}
        for msg in chunk.messages:
            merged_results[msg.line_number] = {"message": msg}
        for item in results:
            if item["message_index"] in merged_results:
                merged_results[item["message_index"]].update(item)
            else:
                merged_results[item["message_index"]] = item.copy()

        sql = """
            INSERT INTO processed (
                chunk_id, model_name, message_index, timestamp, sender,
                message, responsive, reason, criteria, confidence
            ) VALUES (
                ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
            )
            ON CONFLICT (model_name, message_index) DO UPDATE SET
                chunk_id = excluded.chunk_id,
                timestamp = excluded.timestamp,
                sender = excluded.sender,
                message = excluded.message,
                responsive = excluded.responsive,
                reason = excluded.reason,
                criteria = excluded.criteria,
                confidence = excluded.confidence
        """
        with self.db_lock:
            for msg_idx, result_dict in merged_results.items():
                msg = result_dict.get("message", None)
                if msg is None:
                    self.logger.error(
                        f"Result without a message for line {msg_idx}: \n{result_dict}"
                    )
                    continue
                elif not isinstance(msg, Message):
                    self.logger.error(
                        f"Message not a Message for line {msg_idx}: \n{result_dict}"
                    )
                    continue
                self.cursor.execute(
                    sql,
                    (
                        result_dict.get("chunk_id", chunk.chunk_id),
                        model_name,
                        msg.line_number,
                        msg.timestamp,
                        msg.sender,
                        msg.message,
                        result_dict.get("responsive", None),
                        result_dict.get("reason", None),
                        str(result_dict.get("criteria", [])),  # Convert to string
                        result_dict.get("confidence", None),
                    ),
                )
            self.db.commit()

    def _process_chunk_batch(
        self,
        chunk_batch: List[Chunk],
        model_url: str,
        model_name: str,
        system_prompt: str,
        response_format: str,
    ) -> tuple[int, int]:
        """Process a batch of chunks"""
        success = 0
        errors = 0
        for chunk in chunk_batch:
            # Check if this chunk has already been processed
            if self._check_existing_result(chunk, model_name):
                success += 1
                continue

            prompt_messages = []
            prompt_messages.append({"role": "system", "content": system_prompt})
            prompt_messages.append(
                {"role": "user", "content": self._create_user_prompt(chunk)}
            )
            payload = {
                "model": model_name,
                "messages": prompt_messages,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "max_tokens": self.max_tokens,
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "structured_response",
                        "schema": json.loads(response_format),
                    },
                },
            }
            self.logger.debug(f"Payload:\n{str(payload)[:200]}...")

            headers = {"Content-Type": "application/json"}
            if os.getenv("AI_TOKEN"):
                headers["Authorization"] = f"Bearer {os.getenv('AI_TOKEN')}"

            response = None
            try:
                response = requests.post(
                    f"{model_url}/v1/chat/completions",
                    headers=headers,
                    json=payload,
                    timeout=600,
                )
                response.raise_for_status()
                data = response.json()
                if "error" in data:
                    raise RuntimeError(f"LLM error: {data['error']}")

                choices = data.get("choices", [])
                if not choices:
                    raise KeyError("No choices found in response")
                first_choice = choices[0]
                if "message" in first_choice and first_choice["message"]:
                    response_text = first_choice["message"].get("content", "")
                else:
                    response_text = first_choice.get("text", "")
                if not response_text:
                    raise ValueError("No response found")

                result = self._parse_response(response_text, chunk, model_name)
                if result:
                    self._save_result(chunk, result, model_name)
                    success += 1
                else:
                    raise RuntimeError("Could not parse result")
RuntimeError("Could not parse result") except Exception as e: self.logger.error( f"Error processing chunk {chunk.chunk_id} with {model_name}: {str(e)}\n{traceback.format_exc()}" ) if response: self.logger.error(f"Response status: {response.status_code}") self.logger.error(f"Response text: {response.text[:500]}...") self._save_result(chunk, [], model_name) errors += 1 return success, errors def _run_model_inference( self, chunks: List[Chunk], model_url: str, model_name: str, ) -> tuple[int, int]: """Run inference on a single model with parallel processing""" system_prompt = self._create_system_prompt() response_format = self._create_response_format() total_success = 0 total_errors = 0 # Split chunks into batches of batch_size chunk_batches = [ chunks[i : i + self.batch_size] for i in range(0, len(chunks), self.batch_size) ] self.logger.info( f"Processing {len(chunks)} chunks in {len(chunk_batches)} batches of up to {self.batch_size} chunks each for {model_name}" ) # Process batches in parallel with max_workers threads with concurrent.futures.ThreadPoolExecutor( max_workers=self.max_workers ) as executor: # Submit all batch processing tasks future_to_batch = {} for i, batch in enumerate(chunk_batches): future = executor.submit( self._process_chunk_batch, batch, model_url, model_name, system_prompt, response_format, ) future_to_batch[future] = i # Process completed batches with progress bar with tqdm(total=len(chunk_batches), desc=f"{model_name} batches") as pbar: for future in concurrent.futures.as_completed(future_to_batch): batch_idx = future_to_batch[future] try: success, errors = future.result() total_success += success total_errors += errors pbar.set_postfix( { "success": total_success, "errors": total_errors, "batch": f"{batch_idx + 1}/{len(chunk_batches)}", } ) except Exception as e: self.logger.error( f"Batch {batch_idx} failed: {str(e)}\n{traceback.format_exc()}" ) total_errors += len(chunk_batches[batch_idx]) finally: pbar.update(1) return total_success, total_errors def _parse_response( self, response_text: str, chunk: Chunk, model_name: str ) -> list[dict]: """Parse model response""" try: parsed = loads(response_text) # Handle both list and dict responses if isinstance(parsed, dict): parsed_list = [parsed] else: parsed_list = cast(List[Dict], parsed) except Exception as e: self.logger.error( f"Error parsing response for chunk {chunk.chunk_id}: {str(e)}" ) return [] if not parsed_list: return [] responses = [] for result in parsed_list: try: responses.append( { "chunk_id": chunk.chunk_id, "message_index": result.get("message_index", None), "responsive": result.get("responsive", None), "reason": result.get("reason", ""), "criteria": result.get("criteria", []), "confidence": result.get("confidence", "low"), } ) except Exception as e: self.logger.error(f"Error parsing response line: {str(e)}\n{result}") return responses def __del__(self): """Clean up database connection""" if hasattr(self, "db"): self.db.close() if __name__ == "__main__": import argparse parser = argparse.ArgumentParser( description="Run dual Qwen inference with parallel processing" ) parser.add_argument("batch_name") parser.add_argument( "--qwen3-model-name", default="Qwen/Qwen3-30B-A3B-Instruct-2507-FP8" ) parser.add_argument("--qwen3-url", default="http://localhost:8001") parser.add_argument("--qwen25-model-name", default="Qwen/Qwen2.5-72B-Instruct-AWQ") parser.add_argument("--qwen25-url", default="http://localhost:8002") parser.add_argument("--output-dir", default="./pipeline_output") parser.add_argument( 
"--max-workers", type=int, default=4, help="Maximum number of parallel workers per model", ) parser.add_argument( "--batch-size", type=int, default=4, help="Number of chunks in each batch", ) parser.add_argument( "--temperature", type=float, default=0.01, help="LLM Temperature setting", ) parser.add_argument( "--top-p", type=float, default=0.95, help="LLM Top P setting", ) parser.add_argument( "--max-tokens", type=int, default=8192, help="Maximum number of tokens to use for each batch", ) args = parser.parse_args() runner = InferenceRunner( batch_name=args.batch_name, qwen3_model_name=args.qwen3_model_name, qwen3_url=args.qwen3_url, qwen25_model_name=args.qwen25_model_name, qwen25_url=args.qwen25_url, output_dir=args.output_dir, max_workers=args.max_workers, batch_size=args.batch_size, temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens, ) runner.run_inference()