  1. """
  2. Inference runner for dual Qwen models.
  3. """
  4. from collections import defaultdict
  5. import json
  6. import pandas as pd
  7. import requests
  8. from typing import List, Dict, TypedDict, cast
  9. from pathlib import Path
  10. import logging
  11. from tqdm import tqdm
  12. import sqlite3
  13. from json_repair import repair_json, loads
  14. from pipeline.common_defs import Chunk, Message
class InferenceRunner:
    """Run inference on dual Qwen models"""

    def __init__(
        self,
        batch_name: str,
        qwen3_url: str = "http://localhost:8000",
        qwen25_url: str = "http://localhost:8001",
        output_dir: str = "./pipeline_output",
    ):
        self.batch_name = batch_name
        self.qwen3_url = qwen3_url
        self.qwen25_url = qwen25_url
        self.output_dir = Path(output_dir)
        self.logger = logging.getLogger("InferenceRunner")
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(levelname)s - [%(filename)s:%(lineno)d] %(message)s"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.db = sqlite3.connect(
            self.output_dir / f"{batch_name}_processed.db3",
            check_same_thread=False,
        )
        # Set the row factory before creating the cursor so fetched rows can be
        # accessed by column name (e.g. row["num_messages"]).
        self.db.row_factory = sqlite3.Row
        self.cursor = self.db.cursor()
        sql = """
            CREATE TABLE IF NOT EXISTS processed (
                chunk_id INTEGER,
                model_name TEXT,
                message_index INTEGER,
                timestamp DATETIME,
                sender TEXT,
                message TEXT,
                responsive BOOLEAN,
                reason TEXT,
                criteria TEXT,
                confidence TEXT,
                PRIMARY KEY (model_name, message_index)
            );
        """
        self.cursor.execute(sql)
        self.logger.info("summary database initialized")

    def _create_user_prompt(self, chunk: Chunk) -> str:
        """Create inference request for a chunk"""
        # Format messages
        messages_text = ""
        for msg in chunk.messages:
            messages_text += (
                f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message}\n"
            )
        # Create full prompt
        prompt = f"""
Review and classify the following messages.
MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}):
{messages_text}
Provide your response as valid JSON following the specified format.
"""
        return prompt
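
    # Illustrative example of one formatted message line produced above; the
    # line number, timestamp, sender, and text below are made up:
    #
    #   #12 [2024-01-15 09:31] [alice]: can you resend the updated schedule?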

    def _create_chunks(self) -> list[Chunk]:
        with open(self.output_dir / "chunks.json", "r") as f:
            chunk_data = json.load(f)
        msg_df = pd.read_csv(self.output_dir / "preprocessed_messages.csv")
        # Reconstruct chunks (simplified)
        chunks = []
        for item in chunk_data["filtered_chunks"][:10]:  # First 10 for testing
            chunk = Chunk(
                chunk_id=item["chunk_id"],
                start_line=item["start_line"],
                end_line=item["end_line"],
                messages=[],
                combined_text="",
                timestamp_start=item["timestamp_start"],
                timestamp_end=item["timestamp_end"],
            )
            chunk_messages = []
            # start_line/end_line are 1-based and inclusive; iloc is 0-based and
            # end-exclusive, so slice from start_line - 1 up to end_line.
            df_range = msg_df.iloc[item["start_line"] - 1 : item["end_line"]]
            for index, row in df_range.iterrows():
                message = Message(
                    (index + 1),
                    row["timestamp"],
                    row["sender"],
                    row["message_normalized"],
                )
                chunk_messages.append(message)
            chunk.messages = chunk_messages
            chunks.append(chunk)
        return chunks
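
    # A minimal sketch of the chunks.json layout this method assumes; the real
    # file is produced by an earlier pipeline stage and is not shown here, so
    # any fields beyond the ones read above (chunk_id, start_line, end_line,
    # timestamp_start, timestamp_end) are guesses, and the values are made up.
    #
    # {
    #   "filtered_chunks": [
    #     {
    #       "chunk_id": 0,
    #       "start_line": 1,
    #       "end_line": 25,
    #       "timestamp_start": "2024-01-01T09:00:00",
    #       "timestamp_end": "2024-01-01T09:45:00"
    #     }
    #   ]
    # }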

    def run_inference(self, temperature: float = 0.1, max_tokens: int = 2048):
        """Run inference on both models"""
        self.logger.info("=" * 80)
        self.logger.info("RUNNING DUAL QWEN INFERENCE")
        self.logger.info("=" * 80)
        chunks = self._create_chunks()
        self.logger.info("\nRunning Qwen 3 235B inference...")
        self._run_model_inference(
            chunks, self.qwen3_url, "Qwen3-235B", temperature, max_tokens
        )
        self.logger.info("\nRunning Qwen 2.5 72B inference...")
        self._run_model_inference(
            chunks, self.qwen25_url, "Qwen2.5-72B", temperature, max_tokens
        )
        self.logger.info("\n" + "=" * 80)
        self.logger.info("INFERENCE COMPLETE")
        self.logger.info("=" * 80)

    def _create_system_prompt(self) -> str:
        """Create system prompt for LLM"""
        prompt = ""
        with Path(self.output_dir, "system_prompt.txt").open("r") as file:
            prompt = file.read()
        return prompt

    def _create_response_format(self) -> str:
        """Create response format for LLM"""
        response_format = ""
        with Path(self.output_dir, "response_format.json").open() as file:
            response_format = file.read()
        return response_format
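
    # A plausible sketch of response_format.json; the actual schema file is not
    # shown here. It is assumed to describe a top-level JSON array, since
    # _parse_response iterates the parsed result and reads message_index,
    # responsive, reason, criteria, and confidence from each element.
    #
    # {
    #   "type": "array",
    #   "items": {
    #     "type": "object",
    #     "properties": {
    #       "message_index": {"type": "integer"},
    #       "responsive": {"type": "boolean"},
    #       "reason": {"type": "string"},
    #       "criteria": {"type": "array", "items": {"type": "string"}},
    #       "confidence": {"type": "string", "enum": ["low", "medium", "high"]}
    #     },
    #     "required": ["message_index", "responsive"]
    #   }
    # }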

    def _check_existing_result(self, chunk: Chunk, model_name: str) -> bool:
        """Check if result already exists in db"""
        sql = """
            SELECT
                COUNT(*) AS num_messages
            FROM
                processed
            WHERE
                model_name = ?
                AND chunk_id = ?
                AND responsive IS NOT NULL
        """
        self.cursor.execute(sql, (model_name, chunk.chunk_id))
        row: sqlite3.Row = self.cursor.fetchone()
        # Fully processed only if every message in the chunk already has a row
        if row and row["num_messages"] == len(chunk.messages):
            return True
        return False

    def _save_result(self, chunk: Chunk, results: list[dict], model_name: str):
        """Save result to db"""
        # merge the chunk messages with the results
        merged_results = {}
        for msg in chunk.messages:
            merged_results[msg.line_number] = {"message": msg}
        for item in results:
            if item["message_index"] in merged_results:
                merged_results[item["message_index"]].update(item)
            else:
                merged_results[item["message_index"]] = item.copy()
        sql = """
            INSERT INTO processed (
                chunk_id,
                model_name,
                message_index,
                timestamp,
                sender,
                message,
                responsive,
                reason,
                criteria,
                confidence
            ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ? )
            ON CONFLICT (model_name, message_index) DO UPDATE SET
                chunk_id = excluded.chunk_id,
                timestamp = excluded.timestamp,
                sender = excluded.sender,
                message = excluded.message,
                responsive = excluded.responsive,
                reason = excluded.reason,
                criteria = excluded.criteria,
                confidence = excluded.confidence
        """
        for result in merged_results.values():
            msg = result.get("message", None)
            if msg is None or not isinstance(msg, Message):
                self.logger.error(
                    f"somehow we have a result without a message: \n{result}"
                )
                continue
            criteria = result.get("criteria", None)
            if isinstance(criteria, list):
                # sqlite cannot bind a Python list; store criteria as JSON text
                criteria = json.dumps(criteria)
            self.cursor.execute(
                sql,
                (
                    result.get("chunk_id", chunk.chunk_id),
                    model_name,
                    msg.line_number,
                    msg.timestamp,
                    msg.sender,
                    msg.message,
                    result.get("responsive", None),
                    result.get("reason", None),
                    criteria,
                    result.get("confidence", None),
                ),
            )
        self.db.commit()

    def _run_model_inference(
        self,
        chunks: List[Chunk],
        model_url: str,
        model_name: str,
        temperature: float,
        max_tokens: int,
    ):
        """Run inference on a single model"""
        system_prompt = self._create_system_prompt()
        response_format = self._create_response_format()
        success = 0
        errors = 0
        for chunk in tqdm(chunks, desc=f"{model_name} inference"):
            # check if this chunk has already been processed
            if self._check_existing_result(chunk, model_name):
                continue
            prompt_messages = []
            prompt_messages.append({"role": "system", "content": system_prompt})
            prompt_messages.append(
                {"role": "user", "content": self._create_user_prompt(chunk)}
            )
            payload = {
                "model": model_name,
                "messages": prompt_messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "structured_response",
                        "schema": json.loads(response_format),
                    },
                },
            }
            # Additional sampling options that could be added to the payload:
            # "top_p",
            # "top_k",
            # "frequency_penalty",
            # "presence_penalty",
            # # "stop",
            # # "skip_special_tokens",
            # "enable_thinking",
            headers = {"Content-Type": "application/json"}
            response = "Not Processed"
            try:
                # The payload uses chat-style "messages", so post to the chat
                # completions endpoint of the OpenAI-compatible server.
                response = requests.post(
                    f"{model_url}/v1/chat/completions", headers=headers, json=payload
                )
                response.raise_for_status()
                # logger.log(LEVEL_TRACE, f"Response {response.status_code}\n{response.text}")
                data = response.json()
                if "error" in data:
                    raise RuntimeError("LLM error")
                choices = data.get("choices", [])
                if not choices:
                    raise KeyError("No choices found in response")
                first_choice = choices[0]
                if "message" in first_choice and first_choice["message"]:
                    response_text = first_choice["message"].get("content", "")
                else:
                    response_text = first_choice.get("text", "")
                if not response_text:
                    raise ValueError("No response found")
                result = self._parse_response(response_text, chunk, model_name)
                if result:
                    # Persist successfully parsed classifications for this chunk
                    self._save_result(chunk, result, model_name)
                    success += 1
                else:
                    raise RuntimeError("Could not parse result")
            except Exception as e:
                self.logger.error(
                    f"Error processing chunk {chunk.chunk_id}: \nResponse was:\n{response}\n{e}"
                )
                # Record the chunk with empty results so the failure is visible in the db
                self._save_result(chunk, [], model_name)
                errors += 1
        return success, errors

    def _parse_response(
        self, response_text, chunk: Chunk, model_name: str
    ) -> list[dict]:
        """Parse model response"""
        parsed_list = []
        try:
            # json_repair.loads tolerates minor JSON errors in the model output
            parsed = loads(response_text)
            parsed_list = cast(List[Dict], parsed)
        except Exception:
            self.logger.error(f"Error parsing response for chunk {chunk.chunk_id}")
        if not parsed_list:
            return []
        responses = []
        for result in parsed_list:
            try:
                responses.append(
                    {
                        "chunk_id": chunk.chunk_id,
                        "message_index": result.get("message_index", None),
                        "responsive": result.get("responsive", None),
                        "reason": result.get("reason", ""),
                        "criteria": result.get("criteria", []),
                        "confidence": result.get("confidence", "low"),
                    }
                )
            except Exception as e:
                self.logger.error(
                    f"Error parsing response line: \n{e}\n{result}"
                )
        return responses
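

# Illustrative example of the JSON array a model is expected to return for a
# chunk, matching the fields _parse_response reads; the reasons and criteria
# shown are made up, since the system prompt defining them is not included here.
#
# [
#   {
#     "message_index": 42,
#     "responsive": true,
#     "reason": "Discusses the subject matter under review",
#     "criteria": ["criterion_1"],
#     "confidence": "high"
#   },
#   {
#     "message_index": 43,
#     "responsive": false,
#     "reason": "Casual greeting with no relevant content",
#     "criteria": [],
#     "confidence": "medium"
#   }
# ]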


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run dual Qwen inference")
    parser.add_argument("batch_name")
    parser.add_argument("--qwen3-url", default="http://localhost:8001")
    parser.add_argument("--qwen25-url", default="http://localhost:8002")
    parser.add_argument("--output-dir", default="./pipeline_output")
    args = parser.parse_args()
    runner = InferenceRunner(
        args.batch_name, args.qwen3_url, args.qwen25_url, args.output_dir
    )
    runner.run_inference()
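
# Example invocation, assuming two OpenAI-compatible servers are already
# serving the models on the given ports (the batch name and ports below are
# illustrative):
#
#   python inference_runner.py batch_2024_01 \
#       --qwen3-url http://localhost:8001 \
#       --qwen25-url http://localhost:8002 \
#       --output-dir ./pipeline_output
#
# The run expects chunks.json, preprocessed_messages.csv, system_prompt.txt,
# and response_format.json to already exist in the output directory, and
# writes per-message classifications to <batch_name>_processed.db3.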