inference_runner.py

  1. """
  2. Improved Inference runner for dual Qwen models with parallel processing.
  3. """
  4. import json
  5. import os
  6. import traceback
  7. import pandas as pd
  8. import requests
  9. from typing import List, Dict, cast
  10. from pathlib import Path
  11. import logging
  12. from tqdm import tqdm
  13. import sqlite3
  14. from json_repair import loads
  15. import concurrent.futures
  16. import threading
  17. from pipeline.common_defs import Chunk, Message
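# Chunk and Message are the shared data structures from pipeline.common_defs;
# each Chunk covers a contiguous range of message lines (start_line..end_line).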
class InferenceRunner:
    """Run inference on dual Qwen models with parallel processing"""

    def __init__(
        self,
        batch_name: str,
        qwen3_model_name: str = "",
        qwen3_url: str = "http://localhost:8001",
        qwen25_model_name: str = "",
        qwen25_url: str = "http://localhost:8002",
        output_dir: str = "./pipeline_output",
        max_workers: int = 4,  # Number of worker threads for parallel processing
        batch_size: int = 4,
        temperature: float = 0.01,
        top_p: float = 0.95,
        max_tokens: int = 8192,
    ):
        self.batch_name = batch_name
        self.qwen3_model_name = qwen3_model_name
        self.qwen3_url = qwen3_url
        self.qwen25_model_name = qwen25_model_name
        self.qwen25_url = qwen25_url
        self.output_dir = Path(output_dir)
        self.max_workers = max_workers
        self.batch_size = batch_size
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.logger = logging.getLogger("InferenceRunner")
        self.logger.setLevel(logging.DEBUG)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(levelname)s - [%(filename)s:%(lineno)d] %(message)s"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        # Create output directory if it doesn't exist
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Single SQLite connection shared across worker threads; all access is
        # serialized through self.db_lock below.
        self.db = sqlite3.connect(
            self.output_dir / f"{batch_name}_processed.db3",
            check_same_thread=False,
        )
        self.db.row_factory = sqlite3.Row
        self.cursor = self.db.cursor()
        sql = """
            CREATE TABLE IF NOT EXISTS processed (
                chunk_id INTEGER,
                model_name TEXT,
                message_index INTEGER,
                timestamp DATETIME,
                sender TEXT,
                message TEXT,
                responsive BOOLEAN,
                reason TEXT,
                criteria TEXT,
                confidence TEXT,
                PRIMARY KEY (model_name, message_index)
            );
        """
        self.cursor.execute(sql)
        self.db.commit()
        self.logger.info("Summary database initialized")
        # Thread lock for database operations
        self.db_lock = threading.Lock()

    def _create_user_prompt(self, chunk: Chunk) -> str:
        """Create the user prompt for a chunk"""
        # Format messages as "#<line> [timestamp] [sender]: <normalized text>"
        messages_text = ""
        for msg in chunk.messages:
            messages_text += f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message_normalized}\n"
        # Create full prompt
        prompt = f"""
Review and classify the following messages.
MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}):
{messages_text}
Provide your response as valid JSON following the specified format.
"""
        return prompt

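    # chunks.json and preprocessed_messages.csv are read from output_dir; they
    # are expected to be produced by an earlier preprocessing stage of the pipeline.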
    def _create_chunks(self) -> list[Chunk]:
        """Create chunks from preprocessed data"""
        chunks_file = self.output_dir / "chunks.json"
        with open(chunks_file, "r") as f:
            chunk_data = json.load(f)
        msg_file = self.output_dir / "preprocessed_messages.csv"
        msg_df = pd.read_csv(msg_file)
        # Reconstruct chunks
        chunks = []
        for item in chunk_data[:2]:  # Only the first 2 chunks, for testing
            chunk = Chunk(
                chunk_id=item["chunk_id"],
                start_line=item["start_line"],
                end_line=item["end_line"],
                messages=[],
                combined_text="",
                timestamp_start=item["timestamp_start"],
                timestamp_end=item.get("timestamp_end", ""),
            )
            chunk_messages = []
            df_range = msg_df.iloc[item["start_line"] - 1 : item["end_line"]]
            # Positional columns: index, timestamp, sender, message, line_number, message_normalized
            for row in df_range.itertuples():
                message = Message(
                    line_number=row[4],
                    timestamp=row[1],
                    sender=row[2],
                    message=row[3],
                    message_normalized=row[5],
                )
                chunk_messages.append(message)
            chunk.messages = chunk_messages
            chunks.append(chunk)
        return chunks

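    # Entry point: builds the chunks, then runs inference. The Qwen2.5 submission
    # is currently commented out, so only the Qwen3 endpoint is queried.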
    def run_inference(self):
        """Run inference on both models with parallel processing"""
        self.logger.info("=" * 80)
        self.logger.info("RUNNING DUAL QWEN INFERENCE WITH PARALLEL PROCESSING")
        self.logger.info("=" * 80)
        chunks = self._create_chunks()
        if not chunks:
            self.logger.error("No chunks found. Exiting.")
            return
        # Run both models in parallel
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_workers
        ) as executor:
            # Submit both model inference tasks
            qwen3_future = executor.submit(
                self._run_model_inference,
                chunks,
                self.qwen3_url,
                self.qwen3_model_name,
            )
            # qwen25_future = executor.submit(
            #     self._run_model_inference,
            #     chunks,
            #     self.qwen25_url,
            #     self.qwen25_model_name,
            # )
            # Wait for both to complete
            qwen3_success, qwen3_errors = qwen3_future.result()
            # qwen25_success, qwen25_errors = qwen25_future.result()
        self.logger.info(
            f"\nQwen3 results: {qwen3_success} success, {qwen3_errors} errors"
        )
        # self.logger.info(
        #     f"Qwen2.5 results: {qwen25_success} success, {qwen25_errors} errors"
        # )
        self.logger.info("\n" + "=" * 80)
        self.logger.info("INFERENCE COMPLETE")
        self.logger.info("=" * 80)

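    # The system prompt and JSON response schema are loaded from files in
    # output_dir rather than being hard-coded here.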
    def _create_system_prompt(self) -> str:
        """Create system prompt for LLM"""
        prompt_file = Path(self.output_dir, "system_prompt.txt")
        with prompt_file.open("r") as file:
            prompt = file.read()
        return prompt

    def _create_response_format(self) -> str:
        """Create response format for LLM"""
        format_file = Path(self.output_dir, "response_format.json")
        with format_file.open("r") as file:
            response_format = file.read()
        return response_format

    def _check_existing_result(self, chunk: Chunk, model_name: str) -> bool:
        """Check if result already exists in db"""
        with self.db_lock:
            sql = """
                SELECT
                    COUNT(*) AS num_messages
                FROM
                    processed
                WHERE
                    model_name = ?
                    AND chunk_id = ?
                    AND responsive IS NOT NULL
            """
            self.cursor.execute(sql, (model_name, chunk.chunk_id))
            row = self.cursor.fetchone()
            if row and row["num_messages"] == len(chunk.messages):
                return True
        return False

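    # Rows are upserted on the (model_name, message_index) primary key, so
    # re-running a chunk overwrites its previous per-message classifications.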
    def _save_result(self, chunk: Chunk, results: list[dict], model_name: str):
        """Save result to db"""
        # Merge the chunk messages with the results
        merged_results = {}
        for msg in chunk.messages:
            merged_results[msg.line_number] = {"message": msg}
        for item in results:
            if item["message_index"] in merged_results:
                merged_results[item["message_index"]].update(item)
            else:
                merged_results[item["message_index"]] = item.copy()
        sql = """
            INSERT INTO processed (
                chunk_id,
                model_name,
                message_index,
                timestamp,
                sender,
                message,
                responsive,
                reason,
                criteria,
                confidence
            ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ? )
            ON CONFLICT (model_name, message_index) DO UPDATE SET
                chunk_id = excluded.chunk_id,
                timestamp = excluded.timestamp,
                sender = excluded.sender,
                message = excluded.message,
                responsive = excluded.responsive,
                reason = excluded.reason,
                criteria = excluded.criteria,
                confidence = excluded.confidence
        """
        with self.db_lock:
            for msg_idx, result_dict in merged_results.items():
                msg = result_dict.get("message", None)
                if msg is None:
                    self.logger.error(
                        f"Result without a message for line {msg_idx}: \n{result_dict}"
                    )
                    continue
                elif not isinstance(msg, Message):
                    self.logger.error(
                        f"Message is not a Message instance for line {msg_idx}: \n{result_dict}"
                    )
                    continue
                self.cursor.execute(
                    sql,
                    (
                        result_dict.get("chunk_id", chunk.chunk_id),
                        model_name,
                        msg.line_number,
                        msg.timestamp,
                        msg.sender,
                        msg.message,
                        result_dict.get("responsive", None),
                        result_dict.get("reason", None),
                        str(result_dict.get("criteria", [])),  # Store criteria list as a string
                        result_dict.get("confidence", None),
                    ),
                )
            self.db.commit()

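    # Each chunk is sent as one chat-completion request to an OpenAI-compatible
    # /v1/chat/completions endpoint, with structured output requested via a
    # json_schema response_format.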
    def _process_chunk_batch(
        self,
        chunk_batch: List[Chunk],
        model_url: str,
        model_name: str,
        system_prompt: str,
        response_format: str,
    ) -> tuple[int, int]:
        """Process a batch of chunks"""
        success = 0
        errors = 0
        for chunk in chunk_batch:
            # Skip chunks that have already been fully processed
            if self._check_existing_result(chunk, model_name):
                success += 1
                continue
            prompt_messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self._create_user_prompt(chunk)},
            ]
            payload = {
                "model": model_name,
                "messages": prompt_messages,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "max_tokens": self.max_tokens,
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "structured_response",
                        "schema": json.loads(response_format),
                    },
                },
            }
            self.logger.debug(f"Payload:\n{str(payload)[:200]}...")
            headers = {"Content-Type": "application/json"}
            if os.getenv("AI_TOKEN"):
                headers["Authorization"] = f"Bearer {os.getenv('AI_TOKEN')}"
            response = None
            try:
                response = requests.post(
                    f"{model_url}/v1/chat/completions",
                    headers=headers,
                    json=payload,
                    timeout=600,
                )
                response.raise_for_status()
                data = response.json()
                if "error" in data:
                    raise RuntimeError(f"LLM error: {data['error']}")
                choices = data.get("choices", [])
                if not choices:
                    raise KeyError("No choices found in response")
                first_choice = choices[0]
                if "message" in first_choice and first_choice["message"]:
                    response_text = first_choice["message"].get("content", "")
                else:
                    response_text = first_choice.get("text", "")
                if not response_text:
                    raise ValueError("No response found")
                result = self._parse_response(response_text, chunk, model_name)
                if result:
                    self._save_result(chunk, result, model_name)
                    success += 1
                else:
                    raise RuntimeError("Could not parse result")
            except Exception as e:
                self.logger.error(
                    f"Error processing chunk {chunk.chunk_id} with {model_name}: {str(e)}\n{traceback.format_exc()}"
                )
                # A Response with an HTTP error status is falsy, so compare
                # against None explicitly to still log the failed response.
                if response is not None:
                    self.logger.error(f"Response status: {response.status_code}")
                    self.logger.error(f"Response text: {response.text[:500]}...")
                # Record the chunk's messages with NULL classifications so the
                # rows exist; the chunk will be retried on the next run.
                self._save_result(chunk, [], model_name)
                errors += 1
        return success, errors

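    # Chunks are split into batches of batch_size and fanned out across up to
    # max_workers threads; a batch that raises counts all of its chunks as errors.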
    def _run_model_inference(
        self,
        chunks: List[Chunk],
        model_url: str,
        model_name: str,
    ) -> tuple[int, int]:
        """Run inference on a single model with parallel processing"""
        system_prompt = self._create_system_prompt()
        response_format = self._create_response_format()
        total_success = 0
        total_errors = 0
        # Split chunks into batches of batch_size
        chunk_batches = [
            chunks[i : i + self.batch_size]
            for i in range(0, len(chunks), self.batch_size)
        ]
        self.logger.info(
            f"Processing {len(chunks)} chunks in {len(chunk_batches)} batches of up to {self.batch_size} chunks each for {model_name}"
        )
        # Process batches in parallel with max_workers threads
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_workers
        ) as executor:
            # Submit all batch processing tasks
            future_to_batch = {}
            for i, batch in enumerate(chunk_batches):
                future = executor.submit(
                    self._process_chunk_batch,
                    batch,
                    model_url,
                    model_name,
                    system_prompt,
                    response_format,
                )
                future_to_batch[future] = i
            # Process completed batches with progress bar
            with tqdm(total=len(chunk_batches), desc=f"{model_name} batches") as pbar:
                for future in concurrent.futures.as_completed(future_to_batch):
                    batch_idx = future_to_batch[future]
                    try:
                        success, errors = future.result()
                        total_success += success
                        total_errors += errors
                        pbar.set_postfix(
                            {
                                "success": total_success,
                                "errors": total_errors,
                                "batch": f"{batch_idx + 1}/{len(chunk_batches)}",
                            }
                        )
                    except Exception as e:
                        self.logger.error(
                            f"Batch {batch_idx} failed: {str(e)}\n{traceback.format_exc()}"
                        )
                        total_errors += len(chunk_batches[batch_idx])
                    finally:
                        pbar.update(1)
        return total_success, total_errors

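    # json_repair.loads is used instead of json.loads so that slightly malformed
    # JSON emitted by the model can still be recovered.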
    def _parse_response(
        self, response_text: str, chunk: Chunk, model_name: str
    ) -> list[dict]:
        """Parse model response"""
        try:
            parsed = loads(response_text)
            # Handle both list and dict responses
            if isinstance(parsed, dict):
                parsed_list = [parsed]
            else:
                parsed_list = cast(List[Dict], parsed)
        except Exception as e:
            self.logger.error(
                f"Error parsing response for chunk {chunk.chunk_id}: {str(e)}"
            )
            return []
        if not parsed_list:
            return []
        responses = []
        for result in parsed_list:
            try:
                responses.append(
                    {
                        "chunk_id": chunk.chunk_id,
                        "message_index": result.get("message_index", None),
                        "responsive": result.get("responsive", None),
                        "reason": result.get("reason", ""),
                        "criteria": result.get("criteria", []),
                        "confidence": result.get("confidence", "low"),
                    }
                )
            except Exception as e:
                self.logger.error(f"Error parsing response line: {str(e)}\n{result}")
        return responses

    def __del__(self):
        """Clean up database connection"""
        if hasattr(self, "db"):
            self.db.close()


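# Example invocation (assumes the pipeline package is importable and the model
# endpoints are already being served; the batch name here is a placeholder):
#   python inference_runner.py my_batch --max-workers 8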
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Run dual Qwen inference with parallel processing"
    )
    parser.add_argument("batch_name")
    parser.add_argument(
        "--qwen3-model-name", default="Qwen/Qwen3-30B-A3B-Instruct-2507-FP8"
    )
    parser.add_argument("--qwen3-url", default="http://localhost:8001")
    parser.add_argument("--qwen25-model-name", default="Qwen/Qwen2.5-72B-Instruct-AWQ")
    parser.add_argument("--qwen25-url", default="http://localhost:8002")
    parser.add_argument("--output-dir", default="./pipeline_output")
    parser.add_argument(
        "--max-workers",
        type=int,
        default=4,
        help="Maximum number of parallel workers per model",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Number of chunks in each batch",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.01,
        help="LLM temperature setting",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="LLM top-p setting",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=8192,
        help="Maximum number of tokens to generate per request",
    )
    args = parser.parse_args()

    runner = InferenceRunner(
        batch_name=args.batch_name,
        qwen3_model_name=args.qwen3_model_name,
        qwen3_url=args.qwen3_url,
        qwen25_model_name=args.qwen25_model_name,
        qwen25_url=args.qwen25_url,
        output_dir=args.output_dir,
        max_workers=args.max_workers,
        batch_size=args.batch_size,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
    )
    runner.run_inference()