""" Parallel inference runner for dual Qwen models with concurrent processing. """ import json import requests from typing import List, Dict from pathlib import Path import logging from tqdm import tqdm from concurrent.futures import ThreadPoolExecutor, as_completed import time class ParallelInferenceRunner: """Run inference on dual Qwen models with parallel processing""" def __init__(self, qwen3_url: str = "http://localhost:8000", qwen25_url: str = "http://localhost:8001", output_dir: str = './pipeline_output', max_workers: int = 4): self.qwen3_url = qwen3_url self.qwen25_url = qwen25_url self.output_dir = Path(output_dir) self.max_workers = max_workers self.logger = logging.getLogger('ParallelInferenceRunner') self.logger.setLevel(logging.INFO) if not self.logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) self.logger.addHandler(handler) def load_requests(self, requests_file: str) -> List[Dict]: """Load inference requests from JSONL file""" requests_data = [] with open(requests_file, 'r') as f: for line in f: requests_data.append(json.loads(line)) self.logger.info(f"Loaded {len(requests_data)} inference requests") return requests_data def run_inference(self, requests_file: str, temperature: float = 0.1, max_tokens: int = 500): """ Run parallel inference on both models. Args: requests_file: Path to inference requests JSONL temperature: Sampling temperature max_tokens: Maximum tokens to generate """ self.logger.info("=" * 80) self.logger.info("RUNNING PARALLEL DUAL QWEN INFERENCE") self.logger.info("=" * 80) self.logger.info(f"Max workers: {self.max_workers}") # Load requests requests_data = self.load_requests(requests_file) # Run Qwen 3 235B (primary) in parallel self.logger.info("\nRunning Qwen 3 235B inference (parallel)...") start_time = time.time() qwen3_results = self._run_parallel_inference( requests_data, self.qwen3_url, "Qwen3-235B", temperature, max_tokens ) qwen3_time = time.time() - start_time self.logger.info(f"Qwen 3 completed in {qwen3_time:.1f}s") # Save Qwen 3 results qwen3_file = self.output_dir / 'qwen3_results.jsonl' self._save_results(qwen3_results, qwen3_file) # Run Qwen 2.5 72B (secondary) in parallel self.logger.info("\nRunning Qwen 2.5 72B inference (parallel)...") start_time = time.time() qwen25_results = self._run_parallel_inference( requests_data, self.qwen25_url, "Qwen2.5-72B", temperature, max_tokens ) qwen25_time = time.time() - start_time self.logger.info(f"Qwen 2.5 completed in {qwen25_time:.1f}s") # Save Qwen 2.5 results qwen25_file = self.output_dir / 'qwen25_results.jsonl' self._save_results(qwen25_results, qwen25_file) self.logger.info("\n" + "=" * 80) self.logger.info("PARALLEL INFERENCE COMPLETE") self.logger.info("=" * 80) self.logger.info(f"Qwen 3 time: {qwen3_time:.1f}s") self.logger.info(f"Qwen 2.5 time: {qwen25_time:.1f}s") self.logger.info(f"Total time: {qwen3_time + qwen25_time:.1f}s") self.logger.info(f"Speedup: {len(requests_data) * 2 / (qwen3_time + qwen25_time):.1f}x") self.logger.info(f"\nQwen 3 results: {qwen3_file}") self.logger.info(f"Qwen 2.5 results: {qwen25_file}") return str(qwen3_file), str(qwen25_file) def _run_parallel_inference(self, requests_data: List[Dict], model_url: str, model_name: str, temperature: float, max_tokens: int) -> List[Dict]: """Run inference on a single model with parallel workers""" results = [None] * len(requests_data) with ThreadPoolExecutor(max_workers=self.max_workers) as executor: # Submit all tasks 
            future_to_idx = {
                executor.submit(
                    self._process_single_request,
                    req, model_url, model_name, temperature, max_tokens
                ): idx
                for idx, req in enumerate(requests_data)
            }

            # Process completed tasks with progress bar
            with tqdm(total=len(requests_data), desc=f"{model_name}") as pbar:
                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        result = future.result()
                        results[idx] = result
                    except Exception as e:
                        self.logger.error(f"Error processing request {idx}: {e}")
                        results[idx] = self._create_error_result(
                            requests_data[idx], model_name
                        )
                    pbar.update(1)

        return results

    def _process_single_request(self, request: Dict, model_url: str,
                                model_name: str, temperature: float,
                                max_tokens: int) -> Dict:
        """Process a single inference request."""
        try:
            response = requests.post(
                f"{model_url}/v1/completions",
                json={
                    'prompt': request['prompt'],
                    'max_tokens': max_tokens,
                    'temperature': temperature
                },
                timeout=60
            )
            if response.status_code == 200:
                return self._parse_response(response.json(), request, model_name)
            else:
                return self._create_error_result(request, model_name)
        except Exception as e:
            self.logger.error(f"Exception for chunk {request['chunk_id']}: {e}")
            return self._create_error_result(request, model_name)

    def _parse_response(self, response: Dict, request: Dict, model_name: str) -> Dict:
        """Parse model response."""
        try:
            text = response['choices'][0]['text']
            parsed = json.loads(text)
            return {
                'chunk_id': request['chunk_id'],
                'responsive_line_numbers': parsed.get('responsive_line_numbers', []),
                'reasoning': parsed.get('reasoning', ''),
                'confidence': parsed.get('confidence', 'medium'),
                'model_name': model_name
            }
        except Exception:
            return self._create_error_result(request, model_name)

    def _create_error_result(self, request: Dict, model_name: str) -> Dict:
        """Create error result."""
        return {
            'chunk_id': request['chunk_id'],
            'responsive_line_numbers': [],
            'reasoning': 'Error during inference',
            'confidence': 'low',
            'model_name': model_name,
            'error': True
        }

    def _save_results(self, results: List[Dict], filepath: Path):
        """Save results to JSONL."""
        with open(filepath, 'w') as f:
            for result in results:
                f.write(json.dumps(result) + '\n')
        self.logger.info(f"Saved {len(results)} results to {filepath}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run parallel dual Qwen inference')
    parser.add_argument('requests_file', help='Path to inference requests JSONL')
    parser.add_argument('--qwen3-url', default='http://localhost:8000')
    parser.add_argument('--qwen25-url', default='http://localhost:8001')
    parser.add_argument('--output-dir', default='./pipeline_output')
    parser.add_argument('--max-workers', type=int, default=4,
                        help='Number of parallel workers')
    parser.add_argument('--temperature', type=float, default=0.1)
    parser.add_argument('--max-tokens', type=int, default=500)

    args = parser.parse_args()

    runner = ParallelInferenceRunner(
        args.qwen3_url, args.qwen25_url, args.output_dir, args.max_workers
    )
    runner.run_inference(args.requests_file, args.temperature, args.max_tokens)
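
# Example invocation (a minimal sketch; the module file name and requests path below
# are illustrative assumptions, and both OpenAI-compatible /v1/completions servers are
# assumed to already be running on the given ports):
#
#   python parallel_inference_runner.py ./pipeline_output/inference_requests.jsonl \
#       --qwen3-url http://localhost:8000 --qwen25-url http://localhost:8001 \
#       --max-workers 8
#
# Each line of the requests JSONL is expected to carry at least 'chunk_id' and 'prompt'
# fields, since _process_single_request and _create_error_result read both.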