|
|
@@ -2,150 +2,370 @@
|
|
|
Inference runner for dual Qwen models.
|
|
|
"""
|
|
|
|
|
|
+from collections import defaultdict
|
|
|
import json
|
|
|
+import pandas as pd
|
|
|
import requests
|
|
|
-from typing import List, Dict
|
|
|
+from typing import List, Dict, TypedDict, cast
|
|
|
from pathlib import Path
|
|
|
import logging
|
|
|
from tqdm import tqdm
|
|
|
+import sqlite3
|
|
|
+from json_repair import repair_json, loads
|
|
|
+
|
|
|
+from pipeline.common_defs import Chunk, Message
|
|
|
|
|
|
class InferenceRunner:
|
|
|
"""Run inference on dual Qwen models"""
|
|
|
-
|
|
|
- def __init__(self, qwen3_url: str = "http://localhost:8000",
|
|
|
- qwen25_url: str = "http://localhost:8001",
|
|
|
- output_dir: str = "./pipeline_output"):
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ batch_name: str,
|
|
|
+ qwen3_url: str = "http://localhost:8000",
|
|
|
+ qwen25_url: str = "http://localhost:8001",
|
|
|
+ output_dir: str = "./pipeline_output",
|
|
|
+ ):
|
|
|
+ self.batch_name = batch_name
|
|
|
self.qwen3_url = qwen3_url
|
|
|
self.qwen25_url = qwen25_url
|
|
|
self.output_dir = Path(output_dir)
|
|
|
-
|
|
|
+
|
|
|
self.logger = logging.getLogger("InferenceRunner")
|
|
|
self.logger.setLevel(logging.INFO)
|
|
|
-
|
|
|
+
|
|
|
if not self.logger.handlers:
|
|
|
handler = logging.StreamHandler()
|
|
|
- formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
|
|
+ formatter = logging.Formatter(
|
|
|
+ "%(levelname)s - [%(filename)s:%(lineno)d] %(message)s"
|
|
|
+ )
|
|
|
handler.setFormatter(formatter)
|
|
|
self.logger.addHandler(handler)
|
|
|
-
|
|
|
- def load_requests(self, requests_file: str) -> List[Dict]:
|
|
|
- """Load inference requests from JSONL file"""
|
|
|
- requests_data = []
|
|
|
-
|
|
|
- with open(requests_file, "r") as f:
|
|
|
- for line in f:
|
|
|
- requests_data.append(json.loads(line))
|
|
|
-
|
|
|
- self.logger.info(f"Loaded {len(requests_data)} inference requests")
|
|
|
- return requests_data
|
|
|
-
|
|
|
- def run_inference(self, requests_file: str,
|
|
|
- temperature: float = 0.1,
|
|
|
- max_tokens: int = 500):
|
|
|
+
|
|
|
+ self.db = sqlite3.connect(
|
|
|
+ self.output_dir / f"{batch_name}_processed.db3",
|
|
|
+ check_same_thread=False,
|
|
|
+ )
|
|
|
+        # set row_factory before creating the cursor so fetched rows support
+        # name-based column access (e.g. row["num_messages"])
+        self.db.row_factory = sqlite3.Row
+        self.cursor = self.db.cursor()
|
|
|
+ sql = """
|
|
|
+CREATE TABLE IF NOT EXISTS processed (
|
|
|
+ chunk_id INTEGER,
|
|
|
+    model_name TEXT,
|
|
|
+ message_index INTEGER,
|
|
|
+ timestamp DATETIME,
|
|
|
+    sender TEXT,
|
|
|
+ message TEXT,
|
|
|
+ responsive BOOLEAN,
|
|
|
+ reason TEXT,
|
|
|
+ criteria TEXT,
|
|
|
+ confidence TEXT,
|
|
|
+ PRIMARY KEY (model_name, message_index)
|
|
|
+);
|
|
|
+ """
|
|
|
+ self.cursor.execute(sql)
|
|
|
|
|
|
+ self.logger.info("summary database initialized")
|
|
|
+
|
|
|
+ def _create_user_prompt(self, chunk: Chunk) -> str:
|
|
|
+ """Create inference request for a chunk"""
|
|
|
+ # Format messages
|
|
|
+ messages_text = ""
|
|
|
+ for msg in chunk.messages:
|
|
|
+ messages_text += (
|
|
|
+ f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message}\n"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Create full prompt
|
|
|
+ prompt = f"""
|
|
|
+Review and classify the following messages.
|
|
|
+
|
|
|
+MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}):
|
|
|
+
|
|
|
+{messages_text}
|
|
|
+
|
|
|
+Provide your response as valid JSON following the specified format.
|
|
|
+"""
|
|
|
+ return prompt
|
|
|
+
|
|
|
+ def _create_chunks(self) -> list[Chunk]:
|
|
|
+ with open("pipeline_output/chunks.json", "r") as f:
|
|
|
+ chunk_data = json.load(f)
|
|
|
+
|
|
|
+ msg_df = pd.read_csv(self.output_dir / "preprocessed_messages.csv")
|
|
|
+
|
|
|
+ # Reconstruct chunks (simplified)
|
|
|
+ chunks = []
|
|
|
+ for item in chunk_data["filtered_chunks"][:10]: # First 10 for testing
|
|
|
+ chunk = Chunk(
|
|
|
+ chunk_id=item["chunk_id"],
|
|
|
+ start_line=item["start_line"],
|
|
|
+ end_line=item["end_line"],
|
|
|
+ messages=[],
|
|
|
+ combined_text="",
|
|
|
+ timestamp_start=item["timestamp_start"],
|
|
|
+ timestamp_end=item["timetamp_end"],
|
|
|
+ )
|
|
|
+            # slice is end-inclusive: end_line is the last message of the chunk
+            dfRange = msg_df.iloc[item["start_line"] - 1 : item["end_line"]]
+            for index, row in dfRange.iterrows():
+                chunk.messages.append(
+                    Message(
+                        (index + 1),
+                        row["timestamp"],
+                        row["sender"],
+                        row["message_normalized"],
+                    )
+                )
|
|
|
+
|
|
|
+ chunks.append(chunk)
|
|
|
+ return chunks
|
|
|
+
|
|
|
+ def run_inference(self, temperature: float = 0.1, max_tokens: int = 2048):
|
|
|
"""Run inference on both models"""
|
|
|
self.logger.info("=" * 80)
|
|
|
self.logger.info("RUNNING DUAL QWEN INFERENCE")
|
|
|
self.logger.info("=" * 80)
|
|
|
-
|
|
|
- requests_data = self.load_requests(requests_file)
|
|
|
-
|
|
|
+
|
|
|
+ chunks = self._create_chunks()
|
|
|
+
|
|
|
self.logger.info("\nRunning Qwen 3 235B inference...")
|
|
|
- qwen3_results = self._run_model_inference(
|
|
|
- requests_data, self.qwen3_url, "Qwen3-235B", temperature, max_tokens
|
|
|
+ self._run_model_inference(
|
|
|
+ chunks, self.qwen3_url, "Qwen3-235B", temperature, max_tokens
|
|
|
)
|
|
|
-
|
|
|
- qwen3_file = self.output_dir / "qwen3_results.jsonl"
|
|
|
- self._save_results(qwen3_results, qwen3_file)
|
|
|
-
|
|
|
+
|
|
|
self.logger.info("\nRunning Qwen 2.5 72B inference...")
|
|
|
- qwen25_results = self._run_model_inference(
|
|
|
- requests_data, self.qwen25_url, "Qwen2.5-72B", temperature, max_tokens
|
|
|
+ self._run_model_inference(
|
|
|
+ chunks, self.qwen25_url, "Qwen2.5-72B", temperature, max_tokens
|
|
|
)
|
|
|
-
|
|
|
- qwen25_file = self.output_dir / "qwen25_results.jsonl"
|
|
|
- self._save_results(qwen25_results, qwen25_file)
|
|
|
-
|
|
|
+
|
|
|
self.logger.info("\n" + "=" * 80)
|
|
|
self.logger.info("INFERENCE COMPLETE")
|
|
|
self.logger.info("=" * 80)
|
|
|
-
|
|
|
- return str(qwen3_file), str(qwen25_file)
|
|
|
-
|
|
|
- def _run_model_inference(self, requests_data: List[Dict],
|
|
|
- model_url: str, model_name: str,
|
|
|
- temperature: float, max_tokens: int) -> List[Dict]:
|
|
|
+
|
|
|
+ def _create_system_prompt(self) -> str:
|
|
|
+ """Create system prompt for LLM"""
|
|
|
+ prompt = ""
|
|
|
+ with Path(self.output_dir, "system_prompt.txt").open("r") as file:
|
|
|
+ prompt = file.read()
|
|
|
+ return prompt
|
|
|
+
|
|
|
+ def _create_response_format(self) -> str:
|
|
|
+ """Create response format for LLM"""
|
|
|
+ response_format = ""
|
|
|
+ with Path(self.output_dir, "response_format.json").open() as file:
|
|
|
+ response_format = file.read()
|
|
|
+ return response_format
|
|
|
+
|
|
|
+    def _check_existing_result(self, chunk: Chunk, model_name: str) -> bool:
|
|
|
+ """Check if result already exists in db"""
|
|
|
+ sql = """
|
|
|
+SELECT
|
|
|
+ COUNT(*) AS num_messages
|
|
|
+FROM
|
|
|
+ processed
|
|
|
+WHERE
|
|
|
+ model_name = ?
|
|
|
+ AND chunk_id = ?
|
|
|
+ AND responsive IS NOT NULL
|
|
|
+ """
|
|
|
+        self.cursor.execute(sql, (model_name, chunk.chunk_id))
+        row = self.cursor.fetchone()
+        # every message in the chunk already has a stored verdict
+        if row and row["num_messages"] == len(chunk.messages):
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
+ def _save_result(self, chunk: Chunk, results: list[dict], model_name: str):
|
|
|
+ """Save result to db"""
|
|
|
+ # merge the chunk messages with the results
|
|
|
+ merged_results = {}
|
|
|
+ for msg in chunk.messages:
|
|
|
+ merged_results[msg.line_number] = {"message": msg}
|
|
|
+ for item in results:
|
|
|
+ if item["message_index"] in merged_results:
|
|
|
+ merged_results[item["message_index"]].update(item)
|
|
|
+ else:
|
|
|
+ merged_results[item["message_index"]] = item.copy()
|
|
|
+
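+        # upsert keyed on (model_name, message_index): re-running a chunk
+        # overwrites earlier verdicts instead of inserting duplicate rows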
|
|
|
+ sql = """
|
|
|
+INSERT INTO processed (
|
|
|
+ chunk_id,
|
|
|
+ model_name,
|
|
|
+ message_index,
|
|
|
+ timestamp,
|
|
|
+ sender,
|
|
|
+ message,
|
|
|
+ responsive,
|
|
|
+ reason,
|
|
|
+ criteria,
|
|
|
+ confidence
|
|
|
+) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ? )
|
|
|
+ON CONFLICT (model_name, message_index) DO UPDATE SET
|
|
|
+ chunk_id = excluded.chunk_id,
|
|
|
+ timestamp = excluded.timestamp,
|
|
|
+ sender = excluded.sender,
|
|
|
+ message = excluded.message,
|
|
|
+ responsive = excluded.responsive,
|
|
|
+ reason = excluded.reason,
|
|
|
+ criteria = excluded.criteria,
|
|
|
+ confidence = excluded.confidence
|
|
|
+ """
|
|
|
+        for result in merged_results.values():
|
|
|
+ msg = result.get("message", None)
|
|
|
+
|
|
|
+ if msg is None or not isinstance(msg, Message):
|
|
|
+ self.logger.error(
|
|
|
+ f"somehow we have a result without a message: \n{result}"
|
|
|
+ )
|
|
|
+ continue
|
|
|
+
|
|
|
+            # sqlite cannot bind lists/dicts directly; store criteria as JSON text
+            criteria = result.get("criteria", None)
+            if isinstance(criteria, (list, dict)):
+                criteria = json.dumps(criteria)
+
+            self.cursor.execute(
+                sql,
+                (
+                    result.get("chunk_id", chunk.chunk_id),
+                    model_name,
+                    msg.line_number,
+                    msg.timestamp,
+                    msg.sender,
+                    msg.message,
+                    result.get("responsive", None),
+                    result.get("reason", None),
+                    criteria,
+                    result.get("confidence", None),
+                ),
+            )
+        self.db.commit()
|
|
|
+
|
|
|
+ def _run_model_inference(
|
|
|
+ self,
|
|
|
+ chunks: List[Chunk],
|
|
|
+ model_url: str,
|
|
|
+ model_name: str,
|
|
|
+ temperature: float,
|
|
|
+ max_tokens: int,
|
|
|
+ ):
|
|
|
"""Run inference on a single model"""
|
|
|
- results = []
|
|
|
-
|
|
|
- for req in tqdm(requests_data, desc=f"{model_name} inference"):
|
|
|
+ system_prompt = self._create_system_prompt()
|
|
|
+ response_format = self._create_response_format()
|
|
|
+
|
|
|
+ success = 0
|
|
|
+ errors = 0
|
|
|
+
|
|
|
+ for chunk in tqdm(chunks, desc=f"{model_name} inference"):
|
|
|
+ # check if this chunk has already been processed
|
|
|
+ if self._check_existing_result(chunk, model_name):
|
|
|
+ continue
|
|
|
+
|
|
|
+ prompt_messages = []
|
|
|
+ prompt_messages.append({"role": "system", "content": system_prompt})
|
|
|
+ prompt_messages.append(
|
|
|
+ {"role": "user", "content": self._create_user_prompt(chunk)}
|
|
|
+ )
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": model_name,
|
|
|
+ "messages": prompt_messages,
|
|
|
+ "temperature": temperature,
|
|
|
+ "max_tokens": max_tokens,
|
|
|
+ "response_format": {
|
|
|
+ "type": "json_schema",
|
|
|
+ "json_schema": {
|
|
|
+ "name": "structured_response",
|
|
|
+ "schema": json.loads(response_format),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
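+            # note: "response_format" with a json_schema assumes an OpenAI-compatible
+            # server (e.g. vLLM) that supports structured output; servers without it
+            # may ignore or reject the field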
|
|
|
+
|
|
|
+ # "top_p",
|
|
|
+ # "top_k",
|
|
|
+ # "frequency_penalty",
|
|
|
+ # "presence_penalty",
|
|
|
+ # # "stop",
|
|
|
+ # # "skip_special_tokens",
|
|
|
+ # "enable_thinking",
|
|
|
+
|
|
|
+ headers = {"Content-Type": "application/json"}
|
|
|
+
|
|
|
+ response = "Not Processed"
|
|
|
+
|
|
|
try:
|
|
|
response = requests.post(
|
|
|
- f"{model_url}/v1/completions",
|
|
|
- json={
|
|
|
- "prompt": req["prompt"],
|
|
|
- "max_tokens": max_tokens,
|
|
|
- "temperature": temperature
|
|
|
- },
|
|
|
- timeout=60
|
|
|
+ f"{model_url}/v1/completions", headers=headers, json=payload
|
|
|
)
|
|
|
-
|
|
|
- if response.status_code == 200:
|
|
|
- result = self._parse_response(response.json(), req, model_name)
|
|
|
- results.append(result)
|
|
|
+ response.raise_for_status()
|
|
|
+ # logger.log(LEVEL_TRACE, f"Response {response.status_code}\n{response.text}")
|
|
|
+
|
|
|
+ data = response.json()
|
|
|
+ if "error" in data:
|
|
|
+ raise RuntimeError("LLM error")
|
|
|
+
|
|
|
+ choices = data.get("choices", [])
|
|
|
+ if not choices:
|
|
|
+ raise KeyError("No choices found in response")
|
|
|
+
|
|
|
+ first_choice = choices[0]
|
|
|
+ if "message" in first_choice and first_choice["message"]:
|
|
|
+ response_text = first_choice["message"].get("content", "")
|
|
|
else:
|
|
|
- results.append(self._create_error_result(req, model_name))
|
|
|
-
|
|
|
+ response_text = first_choice.get("text", "")
|
|
|
+
|
|
|
+ if not response_text:
|
|
|
+ raise ValueError("No response found")
|
|
|
+
|
|
|
+ result = self._parse_response(response_text, chunk, model_name)
|
|
|
+                if result:
+                    # persist the parsed verdicts for this chunk
+                    self._save_result(chunk, result, model_name)
+                    success += 1
|
|
|
+ else:
|
|
|
+ raise RuntimeError("Could not parse result")
|
|
|
+
|
|
|
except Exception as e:
|
|
|
- self.logger.error(f"Exception for chunk {req['chunk_id']}: {e}")
|
|
|
- results.append(self._create_error_result(req, model_name))
|
|
|
-
|
|
|
- return results
|
|
|
-
|
|
|
- def _parse_response(self, response: Dict, request: Dict, model_name: str) -> Dict:
|
|
|
+ self.logger.error(
|
|
|
+ f"Error processing chunk {chunk.chunk_id}: \nResponse was:\n{response}\n{e.with_traceback}"
|
|
|
+ )
|
|
|
+ self._save_result(chunk, [], model_name)
|
|
|
+ errors += 1
|
|
|
+ return success, errors
|
|
|
+
|
|
|
+ def _parse_response(
|
|
|
+        self, response_text: str, chunk: Chunk, model_name: str
|
|
|
+ ) -> list[dict]:
|
|
|
"""Parse model response"""
|
|
|
+        parsed_list: List[Dict] = []
|
|
|
try:
|
|
|
- text = response["choices"][0]["text"]
|
|
|
- parsed = json.loads(text)
|
|
|
-
|
|
|
- return {
|
|
|
- "chunk_id": request["chunk_id"],
|
|
|
- "responsive_line_numbers": parsed.get("responsive_line_numbers", []),
|
|
|
- "reasoning": parsed.get("reasoning", ""),
|
|
|
- "confidence": parsed.get("confidence", "medium"),
|
|
|
- "model_name": model_name
|
|
|
- }
|
|
|
- except Exception:
|
|
|
- return self._create_error_result(request, model_name)
|
|
|
-
|
|
|
- def _create_error_result(self, request: Dict, model_name: str) -> Dict:
|
|
|
- """Create error result"""
|
|
|
- return {
|
|
|
- "chunk_id": request["chunk_id"],
|
|
|
- "responsive_line_numbers": [],
|
|
|
- "reasoning": "Error during inference",
|
|
|
- "confidence": "low",
|
|
|
- "model_name": model_name,
|
|
|
- "error": True
|
|
|
- }
|
|
|
-
|
|
|
- def _save_results(self, results: List[Dict], filepath: Path):
|
|
|
- """Save results to JSONL"""
|
|
|
- with open(filepath, "w") as f:
|
|
|
- for result in results:
|
|
|
- f.write(json.dumps(result) + "\n")
|
|
|
-
|
|
|
- self.logger.info(f"Saved {len(results)} results to {filepath}")
|
|
|
+ parsed = loads(response_text)
|
|
|
+ parsed_list = cast(List[Dict], parsed)
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(f"Errror parsing response for chunk {chunk.chunk_id}")
|
|
|
+
|
|
|
+ if not parsed_list:
|
|
|
+ return []
|
|
|
+
|
|
|
+ responses = []
|
|
|
+ for result in parsed_list:
|
|
|
+ try:
|
|
|
+ responses.append(
|
|
|
+ {
|
|
|
+ "chunk_id": chunk.chunk_id,
|
|
|
+ "message_index": result.get("message_index", None),
|
|
|
+ "responsive": result.get("responsive", None),
|
|
|
+ "reason": result.get("reason", ""),
|
|
|
+ "criteria": result.get("criteria", []),
|
|
|
+ "confidence": result.get("confidence", "low"),
|
|
|
+ }
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(
|
|
|
+ f"Error parsing response line: \n{e.with_traceback}\n{result}"
|
|
|
+ )
|
|
|
+ return responses
|
|
|
+
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
import argparse
|
|
|
-
|
|
|
+
|
|
|
parser = argparse.ArgumentParser(description="Run dual Qwen inference")
|
|
|
- parser.add_argument("requests_file", help="Path to inference requests JSONL")
|
|
|
- parser.add_argument("--qwen3-url", default="http://localhost:8000")
|
|
|
- parser.add_argument("--qwen25-url", default="http://localhost:8001")
|
|
|
+ parser.add_argument("batch_name")
|
|
|
+ parser.add_argument("--qwen3-url", default="http://localhost:8001")
|
|
|
+ parser.add_argument("--qwen25-url", default="http://localhost:8002")
|
|
|
parser.add_argument("--output-dir", default="./pipeline_output")
|
|
|
-
|
|
|
+
|
|
|
args = parser.parse_args()
|
|
|
-
|
|
|
- runner = InferenceRunner(args.qwen3_url, args.qwen25_url, args.output_dir)
|
|
|
- runner.run_inference(args.requests_file)
|
|
|
+
|
|
|
+ runner = InferenceRunner(
|
|
|
+ args.batch_name, args.qwen3_url, args.qwen25_url, args.output_dir
|
|
|
+ )
|
|
|
+ runner.run_inference()
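+
+    # example invocation (module path and ports are assumptions, adjust to your setup):
+    #   python -m pipeline.inference_runner my_batch --qwen3-url http://localhost:8001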
|