- """
- Inference runner for dual Qwen models with parallel chunk processing.
- """
- import json
- import os
- import traceback
- import pandas as pd
- import requests
- from typing import List, Dict, cast
- from pathlib import Path
- import logging
- from tqdm import tqdm
- import sqlite3
- from json_repair import loads as repair_loads
- import concurrent.futures
- import threading
- from pipeline.common_defs import Chunk, Message
- class InferenceRunner:
- """Run inference on dual Qwen models with parallel processing"""
- def __init__(
- self,
- batch_name: str,
- qwen3_model_name: str = "",
- qwen3_url: str = "http://localhost:8001",
- qwen25_model_name: str = "",
- qwen25_url: str = "http://localhost:8002",
- output_dir: str = "./pipeline_output",
- max_workers: int = 4, # Number of parallel worker threads per model
- batch_size: int = 4,
- temperature: float = 0.01,
- top_p: float = 0.95,
- max_tokens: int = 8192,
- ):
- self.batch_name = batch_name
- self.qwen3_model_name = qwen3_model_name
- self.qwen3_url = qwen3_url
- self.qwen25_model_name = qwen25_model_name
- self.qwen25_url = qwen25_url
- self.output_dir = Path(output_dir)
- self.max_workers = max_workers
- self.batch_size = batch_size
- self.max_tokens = max_tokens
- self.temperature = temperature
- self.top_p = top_p
- self.logger = logging.getLogger("InferenceRunner")
- self.logger.setLevel(logging.DEBUG)
- if not self.logger.handlers:
- handler = logging.StreamHandler()
- formatter = logging.Formatter(
- "%(levelname)s - [%(filename)s:%(lineno)d] %(message)s"
- )
- handler.setFormatter(formatter)
- self.logger.addHandler(handler)
- # Create output directory if it doesn't exist
- self.output_dir.mkdir(parents=True, exist_ok=True)
- self.db = sqlite3.connect(
- self.output_dir / f"{batch_name}_processed.db3",
- check_same_thread=False,
- )
- self.db.row_factory = sqlite3.Row
- self.cursor = self.db.cursor()
- # Results table: one row per (model_name, message_index) pair
- sql = """
- CREATE TABLE IF NOT EXISTS processed (
- chunk_id INTEGER,
- model_name TEXT,
- message_index INTEGER,
- timestamp DATETIME,
- sender TEXT,
- message TEXT,
- responsive BOOLEAN,
- reason TEXT,
- criteria TEXT,
- confidence TEXT,
- PRIMARY KEY (model_name, message_index)
- );
- """
- self.cursor.execute(sql)
- self.db.commit() # Persist the schema before any worker threads use the connection
- self.logger.info("Summary database initialized")
- # Thread lock for database operations
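- # The SQLite connection above is opened with check_same_thread=False, so every
- # cursor access from worker threads must be serialized through this lock.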
- self.db_lock = threading.Lock()
- def _create_user_prompt(self, chunk: Chunk) -> str:
- """Create inference request for a chunk"""
- # Format messages
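- # Each message becomes one line, e.g. (illustrative):
- # "#42 [2024-01-03 09:15] [alice]: please send the signed agreement"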
- messages_text = ""
- for msg in chunk.messages:
- messages_text += f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message_normalized}\n"
- # Create full prompt
- prompt = f"""
- Review and classify the following messages.
- MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}):
- {messages_text}
- Provide your response as valid JSON following the specified format.
- """
- return prompt
- def _create_chunks(self) -> list[Chunk]:
- """Create chunks from preprocessed data"""
- chunks_file = self.output_dir / "chunks.json"
- with open(chunks_file, "r") as f:
- chunk_data = json.load(f)
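- # chunks.json is expected to be a list of dicts shaped roughly like (illustrative values):
- # {"chunk_id": 0, "start_line": 1, "end_line": 50,
- #  "timestamp_start": "2024-01-03 09:15", "timestamp_end": "2024-01-03 11:40"}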
- msg_file = self.output_dir / "preprocessed_messages.csv"
- msg_df = pd.read_csv(msg_file)
- # Reconstruct chunks
- chunks = []
- for item in chunk_data[:2]: # NOTE: limited to the first 2 chunks for testing; iterate over chunk_data for a full run
- chunk = Chunk(
- chunk_id=item["chunk_id"],
- start_line=item["start_line"],
- end_line=item["end_line"],
- messages=[],
- combined_text="",
- timestamp_start=item["timestamp_start"],
- timestamp_end=item.get("timestamp_end", ""),
- )
- chunk_messages = []
- df_range = msg_df.iloc[item["start_line"] - 1 : item["end_line"]]
- # Expected CSV columns: index, timestamp, sender, message, line_number, message_normalized
- # Access columns by name: positional itertuples() indices would be off by one
- # if the leading "index" column is read as ordinary data.
- for row in df_range.itertuples():
- message = Message(
- line_number=row.line_number,
- timestamp=row.timestamp,
- sender=row.sender,
- message=row.message,
- message_normalized=row.message_normalized,
- )
- chunk_messages.append(message)
- chunk.messages = chunk_messages
- chunks.append(chunk)
- return chunks
- def run_inference(self):
- """Run inference on both models with parallel processing"""
- self.logger.info("=" * 80)
- self.logger.info("RUNNING DUAL QWEN INFERENCE WITH PARALLEL PROCESSING")
- self.logger.info("=" * 80)
- chunks = self._create_chunks()
- if not chunks:
- self.logger.error("No chunks found. Exiting.")
- return
- # Run both models in parallel
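- # Note: each model task opens its own ThreadPoolExecutor in _run_model_inference,
- # so in-flight request threads can approach max_workers per model.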
- with concurrent.futures.ThreadPoolExecutor(
- max_workers=self.max_workers
- ) as executor:
- # Submit the model inference task(s); the Qwen2.5 submission below is currently commented out
- qwen3_future = executor.submit(
- self._run_model_inference,
- chunks,
- self.qwen3_url,
- self.qwen3_model_name,
- )
- # qwen25_future = executor.submit(
- # self._run_model_inference,
- # chunks,
- # self.qwen25_url,
- # self.qwen25_model_name,
- # )
- # Wait for the submitted task(s) to complete
- qwen3_success, qwen3_errors = qwen3_future.result()
- # qwen25_success, qwen25_errors = qwen25_future.result()
- self.logger.info(
- f"\nQwen3 results ({self.qwen3_model_name}): {qwen3_success} success, {qwen3_errors} errors"
- )
- # self.logger.info(
- # f"Qwen2.5-72B Results: {qwen25_success} success, {qwen25_errors} errors"
- # )
- self.logger.info("\n" + "=" * 80)
- self.logger.info("INFERENCE COMPLETE")
- self.logger.info("=" * 80)
- def _create_system_prompt(self) -> str:
- """Create system prompt for LLM"""
- prompt_file = Path(self.output_dir, "system_prompt.txt")
- with prompt_file.open("r") as file:
- prompt = file.read()
- return prompt
- def _create_response_format(self) -> str:
- """Create response format for LLM"""
- format_file = Path(self.output_dir, "response_format.json")
- with format_file.open("r") as file:
- response_format = file.read()
- return response_format
- def _check_existing_result(self, chunk: Chunk, model_name: str) -> bool:
- """Check if result already exists in db"""
- with self.db_lock:
- sql = """
- SELECT
- COUNT(*) AS num_messages
- FROM
- processed
- WHERE
- model_name = ?
- AND chunk_id = ?
- AND responsive IS NOT NULL
- """
- self.cursor.execute(sql, (model_name, chunk.chunk_id))
- row = self.cursor.fetchone()
- if row and row["num_messages"] == len(chunk.messages):
- return True
- return False
- def _save_result(self, chunk: Chunk, results: list[dict], model_name: str):
- """Save result to db"""
- # Merge the chunk messages with the results
- merged_results = {}
- for msg in chunk.messages:
- merged_results[msg.line_number] = {"message": msg}
- for item in results:
- if item["message_index"] in merged_results.keys():
- merged_results[item["message_index"]].update(item)
- else:
- merged_results[item["message_index"]] = item.copy()
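- # Upsert keyed on (model_name, message_index): rerunning a chunk overwrites
- # any stale rows for the same messages.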
- sql = """
- INSERT INTO processed (
- chunk_id,
- model_name,
- message_index,
- timestamp,
- sender,
- message,
- responsive,
- reason,
- criteria,
- confidence
- ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ? )
- ON CONFLICT (model_name, message_index) DO UPDATE SET
- chunk_id = excluded.chunk_id,
- timestamp = excluded.timestamp,
- sender = excluded.sender,
- message = excluded.message,
- responsive = excluded.responsive,
- reason = excluded.reason,
- criteria = excluded.criteria,
- confidence = excluded.confidence
- """
- with self.db_lock:
- for msg_idx, result_dict in merged_results.items():
- msg = result_dict.get("message", None)
- if msg is None:
- self.logger.error(
- f"Result without a message for line {msg_idx}: \n{result_dict}"
- )
- continue
- elif not isinstance(msg, Message):
- self.logger.error(
- f"Entry for line {msg_idx} is not a Message instance: \n{result_dict}"
- )
- continue
- self.cursor.execute(
- sql,
- (
- result_dict.get("chunk_id", chunk.chunk_id),
- model_name,
- msg.line_number,
- msg.timestamp,
- msg.sender,
- msg.message,
- result_dict.get("responsive", None),
- result_dict.get("reason", None),
- json.dumps(result_dict.get("criteria", [])), # Store the criteria list as JSON text
- result_dict.get("confidence", None),
- ),
- )
- self.db.commit()
- def _process_chunk_batch(
- self,
- chunk_batch: List[Chunk],
- model_url: str,
- model_name: str,
- system_prompt: str,
- response_format: str,
- ) -> tuple[int, int]:
- """Process a batch of chunks"""
- success = 0
- errors = 0
- for chunk in chunk_batch:
- # Check if this chunk has already been processed
- if self._check_existing_result(chunk, model_name):
- success += 1
- continue
- prompt_messages = []
- prompt_messages.append({"role": "system", "content": system_prompt})
- prompt_messages.append(
- {"role": "user", "content": self._create_user_prompt(chunk)}
- )
- payload = {
- "model": model_name,
- "messages": prompt_messages,
- "temperature": self.temperature,
- "top_p": self.top_p,
- "max_tokens": self.max_tokens,
- "response_format": {
- "type": "json_schema",
- "json_schema": {
- "name": "structured_response",
- "schema": json.loads(response_format),
- },
- },
- }
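- # Structured output via "response_format"/"json_schema" assumes an OpenAI-compatible
- # server that supports it (e.g. a recent vLLM build); adjust if the serving stack differs.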
- self.logger.debug(f"Payload:\n{str(payload)[:200]}...")
- headers = {"Content-Type": "application/json"}
- if os.getenv("AI_TOKEN"):
- headers["Authorization"] = f"Bearer {os.getenv('AI_TOKEN')}"
- response = None
- try:
- response = requests.post(
- f"{model_url}/v1/chat/completions",
- headers=headers,
- json=payload,
- timeout=600,
- )
- response.raise_for_status()
- data = response.json()
- if "error" in data:
- raise RuntimeError(f"LLM error: {data['error']}")
- choices = data.get("choices", [])
- if not choices:
- raise KeyError("No choices found in response")
- first_choice = choices[0]
- if "message" in first_choice and first_choice["message"]:
- response_text = first_choice["message"].get("content", "")
- else:
- response_text = first_choice.get("text", "")
- if not response_text:
- raise ValueError("No response found")
- result = self._parse_response(response_text, chunk, model_name)
- if result:
- self._save_result(chunk, result, model_name)
- success += 1
- else:
- raise RuntimeError("Could not parse result")
- except Exception as e:
- self.logger.error(
- f"Error processing chunk {chunk.chunk_id} with {model_name}: {str(e)}\n{traceback.format_exc()}"
- )
- if response is not None: # bool(response) is False for 4xx/5xx, so compare against None explicitly
- self.logger.error(f"Response status: {response.status_code}")
- self.logger.error(f"Response text: {response.text[:500]}...")
- self._save_result(chunk, [], model_name)
- errors += 1
- return success, errors
- def _run_model_inference(
- self,
- chunks: List[Chunk],
- model_url: str,
- model_name: str,
- ) -> tuple[int, int]:
- """Run inference on a single model with parallel processing"""
- system_prompt = self._create_system_prompt()
- response_format = self._create_response_format()
- total_success = 0
- total_errors = 0
- # Split chunks into batches of batch_size
- chunk_batches = [
- chunks[i : i + self.batch_size]
- for i in range(0, len(chunks), self.batch_size)
- ]
- self.logger.info(
- f"Processing {len(chunks)} chunks in {len(chunk_batches)} batches of up to {self.batch_size} chunks each for {model_name}"
- )
- # Process batches in parallel with max_workers threads
- with concurrent.futures.ThreadPoolExecutor(
- max_workers=self.max_workers
- ) as executor:
- # Submit all batch processing tasks
- future_to_batch = {}
- for i, batch in enumerate(chunk_batches):
- future = executor.submit(
- self._process_chunk_batch,
- batch,
- model_url,
- model_name,
- system_prompt,
- response_format,
- )
- future_to_batch[future] = i
- # Process completed batches with progress bar
- with tqdm(total=len(chunk_batches), desc=f"{model_name} batches") as pbar:
- for future in concurrent.futures.as_completed(future_to_batch):
- batch_idx = future_to_batch[future]
- try:
- success, errors = future.result()
- total_success += success
- total_errors += errors
- pbar.set_postfix(
- {
- "success": total_success,
- "errors": total_errors,
- "batch": f"{batch_idx + 1}/{len(chunk_batches)}",
- }
- )
- except Exception as e:
- self.logger.error(
- f"Batch {batch_idx} failed: {str(e)}\n{traceback.format_exc()}"
- )
- total_errors += len(chunk_batches[batch_idx])
- finally:
- pbar.update(1)
- return total_success, total_errors
- def _parse_response(
- self, response_text: str, chunk: Chunk, model_name: str
- ) -> list[dict]:
- """Parse model response"""
- try:
- parsed = repair_loads(response_text)
- # Handle both list and dict responses
- if isinstance(parsed, dict):
- parsed_list = [parsed]
- else:
- parsed_list = cast(List[Dict], parsed)
- except Exception as e:
- self.logger.error(
- f"Error parsing response for chunk {chunk.chunk_id}: {str(e)}"
- )
- return []
- if not parsed_list:
- return []
- responses = []
- for result in parsed_list:
- try:
- responses.append(
- {
- "chunk_id": chunk.chunk_id,
- "message_index": result.get("message_index", None),
- "responsive": result.get("responsive", None),
- "reason": result.get("reason", ""),
- "criteria": result.get("criteria", []),
- "confidence": result.get("confidence", "low"),
- }
- )
- except Exception as e:
- self.logger.error(f"Error parsing response line: {str(e)}\n{result}")
- return responses
- def __del__(self):
- """Clean up database connection"""
- if hasattr(self, "db"):
- self.db.close()
- if __name__ == "__main__":
- import argparse
- parser = argparse.ArgumentParser(
- description="Run dual Qwen inference with parallel processing"
- )
- parser.add_argument("batch_name")
- parser.add_argument(
- "--qwen3-model-name", default="Qwen/Qwen3-30B-A3B-Instruct-2507-FP8"
- )
- parser.add_argument("--qwen3-url", default="http://localhost:8001")
- parser.add_argument("--qwen25-model-name", default="Qwen/Qwen2.5-72B-Instruct-AWQ")
- parser.add_argument("--qwen25-url", default="http://localhost:8002")
- parser.add_argument("--output-dir", default="./pipeline_output")
- parser.add_argument(
- "--max-workers",
- type=int,
- default=4,
- help="Maximum number of parallel workers per model",
- )
- parser.add_argument(
- "--batch-size",
- type=int,
- default=4,
- help="Number of chunks in each batch",
- )
- parser.add_argument(
- "--temperature",
- type=float,
- default=0.01,
- help="LLM Temperature setting",
- )
- parser.add_argument(
- "--top-p",
- type=float,
- default=0.95,
- help="LLM Top P setting",
- )
- parser.add_argument(
- "--max-tokens",
- type=int,
- default=8192,
- help="Maximum number of tokens to use for each batch",
- )
- args = parser.parse_args()
- runner = InferenceRunner(
- batch_name=args.batch_name,
- qwen3_model_name=args.qwen3_model_name,
- qwen3_url=args.qwen3_url,
- qwen25_model_name=args.qwen25_model_name,
- qwen25_url=args.qwen25_url,
- output_dir=args.output_dir,
- max_workers=args.max_workers,
- batch_size=args.batch_size,
- temperature=args.temperature,
- top_p=args.top_p,
- max_tokens=args.max_tokens,
- )
- runner.run_inference()
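- # Example invocation (illustrative; adjust the script name and endpoints to your setup):
- #   python inference_runner.py my_batch --output-dir ./pipeline_output --max-workers 8 --batch-size 2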