parallel_inference_runner.py

  1. """
  2. Parallel inference runner for dual Qwen models with concurrent processing.
  3. """
  4. import json
  5. import requests
  6. from typing import List, Dict
  7. from pathlib import Path
  8. import logging
  9. from tqdm import tqdm
  10. from concurrent.futures import ThreadPoolExecutor, as_completed
  11. import time
  12. class ParallelInferenceRunner:
  13. """Run inference on dual Qwen models with parallel processing"""
  14. def __init__(self, qwen3_url: str = "http://localhost:8000",
  15. qwen25_url: str = "http://localhost:8001",
  16. output_dir: str = './pipeline_output',
  17. max_workers: int = 4):
  18. self.qwen3_url = qwen3_url
  19. self.qwen25_url = qwen25_url
  20. self.output_dir = Path(output_dir)
  21. self.max_workers = max_workers
  22. self.logger = logging.getLogger('ParallelInferenceRunner')
  23. self.logger.setLevel(logging.INFO)
  24. if not self.logger.handlers:
  25. handler = logging.StreamHandler()
  26. formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
  27. handler.setFormatter(formatter)
  28. self.logger.addHandler(handler)
    def load_requests(self, requests_file: str) -> List[Dict]:
        """Load inference requests from a JSONL file."""
        requests_data = []
        with open(requests_file, 'r') as f:
            for line in f:
                # Skip blank lines so a trailing newline doesn't break json.loads
                if line.strip():
                    requests_data.append(json.loads(line))
        self.logger.info(f"Loaded {len(requests_data)} inference requests")
        return requests_data
    def run_inference(self, requests_file: str,
                      temperature: float = 0.1,
                      max_tokens: int = 500):
        """
        Run parallel inference on both models.

        Args:
            requests_file: Path to inference requests JSONL
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
        """
        self.logger.info("=" * 80)
        self.logger.info("RUNNING PARALLEL DUAL QWEN INFERENCE")
        self.logger.info("=" * 80)
        self.logger.info(f"Max workers: {self.max_workers}")

        # Load requests
        requests_data = self.load_requests(requests_file)

        # Run Qwen 3 235B (primary) in parallel
        self.logger.info("\nRunning Qwen 3 235B inference (parallel)...")
        start_time = time.time()
        qwen3_results = self._run_parallel_inference(
            requests_data,
            self.qwen3_url,
            "Qwen3-235B",
            temperature,
            max_tokens
        )
        qwen3_time = time.time() - start_time
        self.logger.info(f"Qwen 3 completed in {qwen3_time:.1f}s")

        # Save Qwen 3 results
        qwen3_file = self.output_dir / 'qwen3_results.jsonl'
        self._save_results(qwen3_results, qwen3_file)

        # Run Qwen 2.5 72B (secondary) in parallel
        self.logger.info("\nRunning Qwen 2.5 72B inference (parallel)...")
        start_time = time.time()
        qwen25_results = self._run_parallel_inference(
            requests_data,
            self.qwen25_url,
            "Qwen2.5-72B",
            temperature,
            max_tokens
        )
        qwen25_time = time.time() - start_time
        self.logger.info(f"Qwen 2.5 completed in {qwen25_time:.1f}s")

        # Save Qwen 2.5 results
        qwen25_file = self.output_dir / 'qwen25_results.jsonl'
        self._save_results(qwen25_results, qwen25_file)

        total_time = qwen3_time + qwen25_time
        self.logger.info("\n" + "=" * 80)
        self.logger.info("PARALLEL INFERENCE COMPLETE")
        self.logger.info("=" * 80)
        self.logger.info(f"Qwen 3 time: {qwen3_time:.1f}s")
        self.logger.info(f"Qwen 2.5 time: {qwen25_time:.1f}s")
        self.logger.info(f"Total time: {total_time:.1f}s")
        self.logger.info(f"Throughput: {len(requests_data) * 2 / total_time:.1f} requests/s")
        self.logger.info(f"\nQwen 3 results: {qwen3_file}")
        self.logger.info(f"Qwen 2.5 results: {qwen25_file}")

        return str(qwen3_file), str(qwen25_file)
    def _run_parallel_inference(self, requests_data: List[Dict],
                                model_url: str, model_name: str,
                                temperature: float, max_tokens: int) -> List[Dict]:
        """Run inference on a single model with parallel workers."""
        # Pre-size the list so results stay aligned with the original request order
        results = [None] * len(requests_data)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all tasks
            future_to_idx = {
                executor.submit(
                    self._process_single_request,
                    req, model_url, model_name, temperature, max_tokens
                ): idx
                for idx, req in enumerate(requests_data)
            }

            # Process completed tasks with progress bar
            with tqdm(total=len(requests_data), desc=f"{model_name}") as pbar:
                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        results[idx] = future.result()
                    except Exception as e:
                        self.logger.error(f"Error processing request {idx}: {e}")
                        results[idx] = self._create_error_result(
                            requests_data[idx], model_name
                        )
                    pbar.update(1)

        return results
    def _process_single_request(self, request: Dict, model_url: str,
                                model_name: str, temperature: float,
                                max_tokens: int) -> Dict:
        """Process a single inference request."""
        try:
            # POST to the OpenAI-compatible /v1/completions endpoint served at model_url
            response = requests.post(
                f"{model_url}/v1/completions",
                json={
                    'prompt': request['prompt'],
                    'max_tokens': max_tokens,
                    'temperature': temperature
                },
                timeout=60
            )
            if response.status_code == 200:
                return self._parse_response(response.json(), request, model_name)
            self.logger.error(
                f"HTTP {response.status_code} for chunk {request['chunk_id']}"
            )
            return self._create_error_result(request, model_name)
        except Exception as e:
            self.logger.error(f"Exception for chunk {request['chunk_id']}: {e}")
            return self._create_error_result(request, model_name)
    def _parse_response(self, response: Dict, request: Dict,
                        model_name: str) -> Dict:
        """Parse model response"""
        try:
            text = response['choices'][0]['text']
            parsed = json.loads(text)
            return {
                'chunk_id': request['chunk_id'],
                'responsive_line_numbers': parsed.get('responsive_line_numbers', []),
                'reasoning': parsed.get('reasoning', ''),
                'confidence': parsed.get('confidence', 'medium'),
                'model_name': model_name
            }
        except Exception:
            return self._create_error_result(request, model_name)
    def _create_error_result(self, request: Dict, model_name: str) -> Dict:
        """Create error result"""
        return {
            'chunk_id': request['chunk_id'],
            'responsive_line_numbers': [],
            'reasoning': 'Error during inference',
            'confidence': 'low',
            'model_name': model_name,
            'error': True
        }

    def _save_results(self, results: List[Dict], filepath: Path):
        """Save results to JSONL"""
        with open(filepath, 'w') as f:
            for result in results:
                f.write(json.dumps(result) + '\n')
        self.logger.info(f"Saved {len(results)} results to {filepath}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run parallel dual Qwen inference')
    parser.add_argument('requests_file', help='Path to inference requests JSONL')
    parser.add_argument('--qwen3-url', default='http://localhost:8000')
    parser.add_argument('--qwen25-url', default='http://localhost:8001')
    parser.add_argument('--output-dir', default='./pipeline_output')
    parser.add_argument('--max-workers', type=int, default=4,
                        help='Number of parallel workers')
    parser.add_argument('--temperature', type=float, default=0.1)
    parser.add_argument('--max-tokens', type=int, default=500)
    args = parser.parse_args()

    runner = ParallelInferenceRunner(
        args.qwen3_url, args.qwen25_url, args.output_dir, args.max_workers
    )
    runner.run_inference(args.requests_file, args.temperature, args.max_tokens)
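
# Example invocation (a minimal sketch: the requests.jsonl filename and the chunk_id
# value below are illustrative assumptions; the flags and defaults come from the
# argument parser above):
#
#   python parallel_inference_runner.py requests.jsonl \
#       --qwen3-url http://localhost:8000 \
#       --qwen25-url http://localhost:8001 \
#       --max-workers 8
#
# Each line of the requests file must be a JSON object carrying at least 'chunk_id'
# and 'prompt', since those keys are read in _process_single_request and
# _parse_response, e.g.:
#
#   {"chunk_id": "chunk_0001", "prompt": "<prompt text for this chunk>"}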