- """
- Parallel inference runner for dual Qwen models with concurrent processing.
- """
- import json
- import requests
- from typing import List, Dict
- from pathlib import Path
- import logging
- from tqdm import tqdm
- from concurrent.futures import ThreadPoolExecutor, as_completed
- import time


class ParallelInferenceRunner:
    """Run inference on dual Qwen models with parallel processing"""

    def __init__(self, qwen3_url: str = "http://localhost:8000",
                 qwen25_url: str = "http://localhost:8001",
                 output_dir: str = './pipeline_output',
                 max_workers: int = 4):
        self.qwen3_url = qwen3_url
        self.qwen25_url = qwen25_url
        self.output_dir = Path(output_dir)
        # Make sure the output directory exists before any results are written.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.max_workers = max_workers

        self.logger = logging.getLogger('ParallelInferenceRunner')
        self.logger.setLevel(logging.INFO)

        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
    def load_requests(self, requests_file: str) -> List[Dict]:
        """Load inference requests from a JSONL file"""
        requests_data = []

        with open(requests_file, 'r') as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank lines in the JSONL file
                    requests_data.append(json.loads(line))

        self.logger.info(f"Loaded {len(requests_data)} inference requests")
        return requests_data

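    # Expected shape of each line in the requests JSONL (an assumption inferred
    # from the fields this module actually reads -- only 'chunk_id' and
    # 'prompt' are used; any extra fields are ignored), for illustration:
    #   {"chunk_id": "chunk_0001", "prompt": "..."}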
    def run_inference(self, requests_file: str,
                      temperature: float = 0.1,
                      max_tokens: int = 500):
        """
        Run parallel inference on both models.

        Args:
            requests_file: Path to inference requests JSONL
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
        """
        self.logger.info("=" * 80)
        self.logger.info("RUNNING PARALLEL DUAL QWEN INFERENCE")
        self.logger.info("=" * 80)
        self.logger.info(f"Max workers: {self.max_workers}")

        # Load requests
        requests_data = self.load_requests(requests_file)

        # Run Qwen 3 235B (primary) in parallel
        self.logger.info("\nRunning Qwen 3 235B inference (parallel)...")
        start_time = time.time()
        qwen3_results = self._run_parallel_inference(
            requests_data,
            self.qwen3_url,
            "Qwen3-235B",
            temperature,
            max_tokens
        )
        qwen3_time = time.time() - start_time
        self.logger.info(f"Qwen 3 completed in {qwen3_time:.1f}s")

        # Save Qwen 3 results
        qwen3_file = self.output_dir / 'qwen3_results.jsonl'
        self._save_results(qwen3_results, qwen3_file)

        # Run Qwen 2.5 72B (secondary) in parallel
        self.logger.info("\nRunning Qwen 2.5 72B inference (parallel)...")
        start_time = time.time()
        qwen25_results = self._run_parallel_inference(
            requests_data,
            self.qwen25_url,
            "Qwen2.5-72B",
            temperature,
            max_tokens
        )
        qwen25_time = time.time() - start_time
        self.logger.info(f"Qwen 2.5 completed in {qwen25_time:.1f}s")

        # Save Qwen 2.5 results
        qwen25_file = self.output_dir / 'qwen25_results.jsonl'
        self._save_results(qwen25_results, qwen25_file)

        total_time = qwen3_time + qwen25_time
        self.logger.info("\n" + "=" * 80)
        self.logger.info("PARALLEL INFERENCE COMPLETE")
        self.logger.info("=" * 80)
        self.logger.info(f"Qwen 3 time: {qwen3_time:.1f}s")
        self.logger.info(f"Qwen 2.5 time: {qwen25_time:.1f}s")
        self.logger.info(f"Total time: {total_time:.1f}s")
        # Aggregate request rate across both model passes.
        if total_time > 0:
            self.logger.info(f"Throughput: {len(requests_data) * 2 / total_time:.1f} requests/s")
        self.logger.info(f"\nQwen 3 results: {qwen3_file}")
        self.logger.info(f"Qwen 2.5 results: {qwen25_file}")

        return str(qwen3_file), str(qwen25_file)

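    # Design note: requests are fanned out across max_workers threads for one
    # model at a time, so the Qwen 3 and Qwen 2.5 passes above run back to back
    # rather than concurrently with each other. Overlapping the two passes
    # (for example by submitting both _run_parallel_inference calls to an outer
    # executor) is a possible variation, not something this runner does.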
    def _run_parallel_inference(self, requests_data: List[Dict],
                                model_url: str, model_name: str,
                                temperature: float, max_tokens: int) -> List[Dict]:
        """Run inference on a single model with parallel workers"""
        results = [None] * len(requests_data)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all tasks
            future_to_idx = {
                executor.submit(
                    self._process_single_request,
                    req, model_url, model_name, temperature, max_tokens
                ): idx
                for idx, req in enumerate(requests_data)
            }

            # Process completed tasks with progress bar
            with tqdm(total=len(requests_data), desc=f"{model_name}") as pbar:
                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        result = future.result()
                        results[idx] = result
                    except Exception as e:
                        self.logger.error(f"Error processing request {idx}: {e}")
                        results[idx] = self._create_error_result(
                            requests_data[idx], model_name
                        )
                    pbar.update(1)

        return results

    def _process_single_request(self, request: Dict, model_url: str,
                                model_name: str, temperature: float,
                                max_tokens: int) -> Dict:
        """Process a single inference request"""
        try:
            # Assumes an OpenAI-compatible /v1/completions endpoint; some
            # servers may also expect a 'model' field in the JSON payload.
            response = requests.post(
                f"{model_url}/v1/completions",
                json={
                    'prompt': request['prompt'],
                    'max_tokens': max_tokens,
                    'temperature': temperature
                },
                timeout=60
            )

            if response.status_code == 200:
                return self._parse_response(response.json(), request, model_name)
            else:
                self.logger.error(
                    f"HTTP {response.status_code} from {model_name} "
                    f"for chunk {request['chunk_id']}"
                )
                return self._create_error_result(request, model_name)

        except Exception as e:
            self.logger.error(f"Exception for chunk {request['chunk_id']}: {e}")
            return self._create_error_result(request, model_name)

    def _parse_response(self, response: Dict, request: Dict,
                        model_name: str) -> Dict:
        """Parse model response"""
        try:
            text = response['choices'][0]['text']
            parsed = json.loads(text)

            return {
                'chunk_id': request['chunk_id'],
                'responsive_line_numbers': parsed.get('responsive_line_numbers', []),
                'reasoning': parsed.get('reasoning', ''),
                'confidence': parsed.get('confidence', 'medium'),
                'model_name': model_name
            }
        except Exception:
            return self._create_error_result(request, model_name)

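    # The completion text is expected to be a JSON object roughly of this form
    # (an assumption inferred from the keys read above; the values shown are
    # illustrative):
    #   {"responsive_line_numbers": [12, 47], "reasoning": "...", "confidence": "high"}
    # Anything that fails to parse as JSON falls back to an error result.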
    def _create_error_result(self, request: Dict, model_name: str) -> Dict:
        """Create error result"""
        return {
            'chunk_id': request['chunk_id'],
            'responsive_line_numbers': [],
            'reasoning': 'Error during inference',
            'confidence': 'low',
            'model_name': model_name,
            'error': True
        }

    def _save_results(self, results: List[Dict], filepath: Path):
        """Save results to JSONL"""
        with open(filepath, 'w') as f:
            for result in results:
                f.write(json.dumps(result) + '\n')

        self.logger.info(f"Saved {len(results)} results to {filepath}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run parallel dual Qwen inference')
    parser.add_argument('requests_file', help='Path to inference requests JSONL')
    parser.add_argument('--qwen3-url', default='http://localhost:8000')
    parser.add_argument('--qwen25-url', default='http://localhost:8001')
    parser.add_argument('--output-dir', default='./pipeline_output')
    parser.add_argument('--max-workers', type=int, default=4,
                        help='Number of parallel workers')
    parser.add_argument('--temperature', type=float, default=0.1)
    parser.add_argument('--max-tokens', type=int, default=500)

    args = parser.parse_args()

    runner = ParallelInferenceRunner(
        args.qwen3_url, args.qwen25_url, args.output_dir, args.max_workers
    )
    runner.run_inference(args.requests_file, args.temperature, args.max_tokens)
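
# Example invocation (the script filename and requests path are illustrative;
# both model servers are assumed to already be running on the default ports):
#   python parallel_inference.py requests.jsonl --max-workers 8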