inference_runner.py

  1. """
  2. Inference runner for dual Qwen models.
  3. """
  4. import json
  5. import requests
  6. from typing import List, Dict
  7. from pathlib import Path
  8. import logging
  9. from tqdm import tqdm
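
# Expected input: one JSON object per line. Only "chunk_id" and "prompt" are
# read by this runner (field names inferred from usage below); the example
# chunk id is illustrative:
#   {"chunk_id": "doc-0001", "prompt": "..."}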


class InferenceRunner:
    """Run inference on dual Qwen models."""

    def __init__(self, qwen3_url: str = "http://localhost:8000",
                 qwen25_url: str = "http://localhost:8001",
                 output_dir: str = "./pipeline_output"):
        self.qwen3_url = qwen3_url
        self.qwen25_url = qwen25_url
        self.output_dir = Path(output_dir)
        # Create the output directory up front so result writes cannot fail.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.logger = logging.getLogger("InferenceRunner")
        self.logger.setLevel(logging.INFO)
        # Guard against duplicate handlers if the module is imported twice.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def load_requests(self, requests_file: str) -> List[Dict]:
        """Load inference requests from a JSONL file."""
        requests_data = []
        with open(requests_file, "r") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank lines, e.g. a stray trailing newline
                    requests_data.append(json.loads(line))
        self.logger.info(f"Loaded {len(requests_data)} inference requests")
        return requests_data

    def run_inference(self, requests_file: str,
                      temperature: float = 0.1,
                      max_tokens: int = 500):
        """Run inference on both models and return the two result file paths."""
        self.logger.info("=" * 80)
        self.logger.info("RUNNING DUAL QWEN INFERENCE")
        self.logger.info("=" * 80)
        requests_data = self.load_requests(requests_file)

        self.logger.info("\nRunning Qwen 3 235B inference...")
        qwen3_results = self._run_model_inference(
            requests_data, self.qwen3_url, "Qwen3-235B", temperature, max_tokens
        )
        qwen3_file = self.output_dir / "qwen3_results.jsonl"
        self._save_results(qwen3_results, qwen3_file)

        self.logger.info("\nRunning Qwen 2.5 72B inference...")
        qwen25_results = self._run_model_inference(
            requests_data, self.qwen25_url, "Qwen2.5-72B", temperature, max_tokens
        )
        qwen25_file = self.output_dir / "qwen25_results.jsonl"
        self._save_results(qwen25_results, qwen25_file)

        self.logger.info("\n" + "=" * 80)
        self.logger.info("INFERENCE COMPLETE")
        self.logger.info("=" * 80)
        return str(qwen3_file), str(qwen25_file)

    def _run_model_inference(self, requests_data: List[Dict],
                             model_url: str, model_name: str,
                             temperature: float, max_tokens: int) -> List[Dict]:
        """Run inference against a single model endpoint."""
        results = []
        for req in tqdm(requests_data, desc=f"{model_name} inference"):
            try:
                # Assumes an OpenAI-style completions endpoint, as served by
                # vLLM and similar inference servers.
                response = requests.post(
                    f"{model_url}/v1/completions",
                    json={
                        "prompt": req["prompt"],
                        "max_tokens": max_tokens,
                        "temperature": temperature,
                    },
                    timeout=60,
                )
                if response.status_code == 200:
                    results.append(self._parse_response(response.json(), req, model_name))
                else:
                    self.logger.warning(
                        f"HTTP {response.status_code} for chunk {req['chunk_id']}"
                    )
                    results.append(self._create_error_result(req, model_name))
            except Exception as e:
                self.logger.error(f"Exception for chunk {req['chunk_id']}: {e}")
                results.append(self._create_error_result(req, model_name))
        return results

    def _parse_response(self, response: Dict, request: Dict, model_name: str) -> Dict:
        """Parse a model response; the completion text is expected to be JSON."""
        try:
            text = response["choices"][0]["text"]
            parsed = json.loads(text)
            return {
                "chunk_id": request["chunk_id"],
                "responsive_line_numbers": parsed.get("responsive_line_numbers", []),
                "reasoning": parsed.get("reasoning", ""),
                "confidence": parsed.get("confidence", "medium"),
                "model_name": model_name,
            }
        except Exception:
            # Missing fields or a completion that is not valid JSON.
            return self._create_error_result(request, model_name)

    def _create_error_result(self, request: Dict, model_name: str) -> Dict:
        """Create an error result for a failed request."""
        return {
            "chunk_id": request["chunk_id"],
            "responsive_line_numbers": [],
            "reasoning": "Error during inference",
            "confidence": "low",
            "model_name": model_name,
            "error": True,
        }

    def _save_results(self, results: List[Dict], filepath: Path):
        """Save results to a JSONL file."""
        with open(filepath, "w") as f:
            for result in results:
                f.write(json.dumps(result) + "\n")
        self.logger.info(f"Saved {len(results)} results to {filepath}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run dual Qwen inference")
    parser.add_argument("requests_file", help="Path to inference requests JSONL")
    parser.add_argument("--qwen3-url", default="http://localhost:8000",
                        help="Base URL of the Qwen 3 235B server")
    parser.add_argument("--qwen25-url", default="http://localhost:8001",
                        help="Base URL of the Qwen 2.5 72B server")
    parser.add_argument("--output-dir", default="./pipeline_output",
                        help="Directory for result files")
    args = parser.parse_args()

    runner = InferenceRunner(args.qwen3_url, args.qwen25_url, args.output_dir)
    runner.run_inference(args.requests_file)
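
# Usage sketch (the requests file name is illustrative):
#   python inference_runner.py requests.jsonl --qwen3-url http://localhost:8000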