
parallel universe

adri 4 weeks ago
commit bb1d2edadf
1 changed file with 299 additions and 107 deletions

+ 299 - 107
pipeline/utils/inference_runner.py

@@ -1,37 +1,55 @@
 """
-Inference runner for dual Qwen models.
+Improved inference runner for dual Qwen models with parallel processing.
 """
 
-from collections import defaultdict
 import json
+import os
+import traceback
 import pandas as pd
 import requests
-from typing import List, Dict, TypedDict, cast
+from typing import List, Dict, cast
 from pathlib import Path
 import logging
 from tqdm import tqdm
 import sqlite3
-from json_repair import repair_json, loads
+from json_repair import loads
+import concurrent.futures
+import threading
 
 from pipeline.common_defs import Chunk, Message
 
+
 class InferenceRunner:
-    """Run inference on dual Qwen models"""
+    """Run inference on dual Qwen models with parallel processing"""
 
     def __init__(
         self,
         batch_name: str,
-        qwen3_url: str = "http://localhost:8000",
-        qwen25_url: str = "http://localhost:8001",
+        qwen3_model_name: str = "",
+        qwen3_url: str = "http://localhost:8001",
+        qwen25_model_name: str = "",
+        qwen25_url: str = "http://localhost:8002",
         output_dir: str = "./pipeline_output",
+        max_workers: int = 4,  # New parameter for parallel processing
+        batch_size: int = 4,
+        temperature: float = 0.01,
+        top_p: float = 0.95,
+        max_tokens: int = 8192,
     ):
         self.batch_name = batch_name
+        self.qwen3_model_name = qwen3_model_name
         self.qwen3_url = qwen3_url
+        self.qwen25_model_name = qwen25_model_name
         self.qwen25_url = qwen25_url
         self.output_dir = Path(output_dir)
+        self.max_workers = max_workers
+        self.batch_size = batch_size
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.top_p = top_p
 
         self.logger = logging.getLogger("InferenceRunner")
-        self.logger.setLevel(logging.INFO)
+        self.logger.setLevel(logging.DEBUG)
 
         if not self.logger.handlers:
             handler = logging.StreamHandler()
@@ -41,18 +59,24 @@ class InferenceRunner:
             handler.setFormatter(formatter)
             self.logger.addHandler(handler)
 
+        # Create output directory if it doesn't exist
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
         self.db = sqlite3.connect(
             self.output_dir / f"{batch_name}_processed.db3",
             check_same_thread=False,
         )
+        self.db.row_factory = sqlite3.Row
         self.cursor = self.db.cursor()
+
+        # Fixed SQL syntax error (removed colon after TEXT)
         sql = """
 CREATE TABLE IF NOT EXISTS processed (
     chunk_id INTEGER,
-    model_name: TEXT,
+    model_name TEXT,
     message_index INTEGER,
     timestamp DATETIME,
-    sender: TEXT,
+    sender TEXT,
     message TEXT,
     responsive BOOLEAN,
     reason TEXT,
@@ -62,17 +86,18 @@ CREATE TABLE IF NOT EXISTS processed (
 );
         """
         self.cursor.execute(sql)
-        self.db.row_factory = sqlite3.Row
-        self.logger.info("summary database initialized")
+        self.db.commit()  # Added commit
+        self.logger.info("Summary database initialized")
+
+        # Thread lock for database operations
+        self.db_lock = threading.Lock()
 
     def _create_user_prompt(self, chunk: Chunk) -> str:
         """Create inference request for a chunk"""
         # Format messages
         messages_text = ""
         for msg in chunk.messages:
-            messages_text += (
-                f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message}\n"
-            )
+            messages_text += f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message_normalized}\n"
 
         # Create full prompt
         prompt = f"""
@@ -87,14 +112,19 @@ Provide your response as valid JSON following the specified format.
         return prompt
 
     def _create_chunks(self) -> list[Chunk]:
-        with open("pipeline_output/chunks.json", "r") as f:
+        """Create chunks from preprocessed data"""
+        chunks_file = self.output_dir / "chunks.json"
+
+        with open(chunks_file, "r") as f:
             chunk_data = json.load(f)
 
-        msg_df = pd.read_csv(self.output_dir / "preprocessed_messages.csv")
+        msg_file = self.output_dir / "preprocessed_messages.csv"
 
-        # Reconstruct chunks (simplified)
+        msg_df = pd.read_csv(msg_file)
+
+        # Reconstruct chunks
         chunks = []
-        for item in chunk_data["filtered_chunks"][:10]:  # First 10 for testing
+        for item in chunk_data[:2]:  # First 2 for testing
             chunk = Chunk(
                 chunk_id=item["chunk_id"],
                 start_line=item["start_line"],
@@ -102,60 +132,90 @@ Provide your response as valid JSON following the specified format.
                 messages=[],
                 combined_text="",
                 timestamp_start=item["timestamp_start"],
-                timestamp_end=item["timetamp_end"],
+                timestamp_end=item.get("timestamp_end", ""),
             )
+
             chunk_messages = []
-            dfRange = msg_df.iloc[item["start_line"] - 1 : item["end_line"] - 1]
-            for index, row in dfRange.itertuples():
+            dfRange = msg_df.iloc[item["start_line"] - 1 : item["end_line"]]
+            # index,timestamp,sender,message,line_number,message_normalized
+            for row in dfRange.itertuples():
                 message = Message(
-                    (index + 1),
-                    row["timestamp"],
-                    row["sender"],
-                    row["message_normalized"],
+                    line_number=row[4],
+                    timestamp=row[1],
+                    sender=row[2],
+                    message=row[3],
+                    message_normalized=row[5],
                 )
+                chunk_messages.append(message)
 
+            chunk.messages = chunk_messages
             chunks.append(chunk)
         return chunks
 
-    def run_inference(self, temperature: float = 0.1, max_tokens: int = 2048):
-        """Run inference on both models"""
+    def run_inference(self):
+        """Run inference on both models with parallel processing"""
         self.logger.info("=" * 80)
-        self.logger.info("RUNNING DUAL QWEN INFERENCE")
+        self.logger.info("RUNNING DUAL QWEN INFERENCE WITH PARALLEL PROCESSING")
         self.logger.info("=" * 80)
 
         chunks = self._create_chunks()
+        if not chunks:
+            self.logger.error("No chunks found. Exiting.")
+            return
+
+        # Run both models in parallel
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=self.max_workers
+        ) as executor:
+            # Submit both model inference tasks
+            qwen3_future = executor.submit(
+                self._run_model_inference,
+                chunks,
+                self.qwen3_url,
+                self.qwen3_model_name,
+            )
 
-        self.logger.info("\nRunning Qwen 3 235B inference...")
-        self._run_model_inference(
-            chunks, self.qwen3_url, "Qwen3-235B", temperature, max_tokens
-        )
+            # qwen25_future = executor.submit(
+            #     self._run_model_inference,
+            #     chunks,
+            #     self.qwen25_url,
+            #     self.qwen25_model_name,
+            # )
 
-        self.logger.info("\nRunning Qwen 2.5 72B inference...")
-        self._run_model_inference(
-            chunks, self.qwen25_url, "Qwen2.5-72B", temperature, max_tokens
-        )
+            # Wait for both to complete
+            qwen3_success, qwen3_errors = qwen3_future.result()
+            # qwen25_success, qwen25_errors = qwen25_future.result()
 
+        self.logger.info(
+            f"\nQwen3-235B Results: {qwen3_success} success, {qwen3_errors} errors"
+        )
+        # self.logger.info(
+        #     f"Qwen2.5-72B Results: {qwen25_success} success, {qwen25_errors} errors"
+        # )
         self.logger.info("\n" + "=" * 80)
         self.logger.info("INFERENCE COMPLETE")
         self.logger.info("=" * 80)
 
     def _create_system_prompt(self) -> str:
         """Create system prompt for LLM"""
-        prompt = ""
-        with Path(self.output_dir, "system_prompt.txt").open("r") as file:
+        prompt_file = Path(self.output_dir, "system_prompt.txt")
+
+        with prompt_file.open("r") as file:
             prompt = file.read()
         return prompt
 
     def _create_response_format(self) -> str:
         """Create response format for LLM"""
-        response_format = ""
-        with Path(self.output_dir, "response_format.json").open() as file:
+        format_file = Path(self.output_dir, "response_format.json")
+
+        with format_file.open("r") as file:
             response_format = file.read()
         return response_format
 
-    def _check_existing_result(self, chunk: Chunk, model_name) -> bool:
+    def _check_existing_result(self, chunk: Chunk, model_name: str) -> bool:
         """Check if result already exists in db"""
-        sql = """
+        with self.db_lock:
+            sql = """
 SELECT 
     COUNT(*) AS num_messages
 FROM
@@ -164,21 +224,22 @@ WHERE
     model_name = ?
     AND chunk_id = ?
     AND responsive IS NOT NULL
-        """
-        result = self.cursor.execute(sql, (model_name, chunk.chunk_id))
-        row: dict = self.cursor.fetchone()
-        if row and row["0"] == len(chunk.messages):
-            return True
-        return False
+            """
+            result = self.cursor.execute(sql, (model_name, chunk.chunk_id))
+            row = self.cursor.fetchone()
+            if row and row["num_messages"] == len(chunk.messages):
+                return True
+            return False
 
     def _save_result(self, chunk: Chunk, results: list[dict], model_name: str):
         """Save result to db"""
-        # merge the chunk messages with the results
+        # Merge the chunk messages with the results
         merged_results = {}
         for msg in chunk.messages:
             merged_results[msg.line_number] = {"message": msg}
+
         for item in results:
-            if item["message_index"] in merged_results:
+            if item["message_index"] in merged_results.keys():
                 merged_results[item["message_index"]].update(item)
             else:
                 merged_results[item["message_index"]] = item.copy()
@@ -206,48 +267,55 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
     criteria = excluded.criteria,
     confidence = excluded.confidence
         """
-        for result in merged_results:
-            msg = result.get("message", None)
 
-            if msg is None or not isinstance(msg, Message):
-                self.logger.error(
-                    f"somehow we have a result without a message: \n{result}"
+        with self.db_lock:
+            for msg_idx, result_dict in merged_results.items():
+                msg = result_dict.get("message", None)
+
+                if msg is None:
+                    self.logger.error(
+                        f"Result without a message for line {msg_idx}: \n{result_dict}"
+                    )
+                    continue
+                elif not isinstance(msg, Message):
+                    self.logger.error(
+                        f"Message not a Message for line {msg_idx}: \n{result_dict}"
+                    )
+                    continue
+
+                self.cursor.execute(
+                    sql,
+                    (
+                        result_dict.get("chunk_id", chunk.chunk_id),
+                        model_name,
+                        msg.line_number,
+                        msg.timestamp,
+                        msg.sender,
+                        msg.message,
+                        result_dict.get("responsive", None),
+                        result_dict.get("reason", None),
+                        str(result_dict.get("criteria", [])),  # Convert to string
+                        result_dict.get("confidence", None),
+                    ),
                 )
-                continue
-
-            self.cursor.execute(
-                sql,
-                (
-                    result.get("chunk_id", None),
-                    model_name,
-                    msg.line_number,
-                    msg.timestamp,
-                    msg.sender,
-                    msg.message,
-                    result.get("responsive", None),
-                    result.get("reason", None),
-                    result.get("criteria", None),
-                ),
-            )
+            self.db.commit()
 
-    def _run_model_inference(
+    def _process_chunk_batch(
         self,
-        chunks: List[Chunk],
+        chunk_batch: List[Chunk],
         model_url: str,
         model_name: str,
-        temperature: float,
-        max_tokens: int,
-    ):
-        """Run inference on a single model"""
-        system_prompt = self._create_system_prompt()
-        response_format = self._create_response_format()
-
+        system_prompt: str,
+        response_format: str,
+    ) -> tuple[int, int]:
+        """Process a batch of chunks"""
         success = 0
         errors = 0
 
-        for chunk in tqdm(chunks, desc=f"{model_name} inference"):
-            # check if this chunk has already been processed
+        for chunk in chunk_batch:
+            # Check if this chunk has already been processed
             if self._check_existing_result(chunk, model_name):
+                success += 1
                 continue
 
             prompt_messages = []
@@ -259,8 +327,9 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
             payload = {
                 "model": model_name,
                 "messages": prompt_messages,
-                "temperature": temperature,
-                "max_tokens": max_tokens,
+                "temperature": self.temperature,
+                "top_p": self.top_p,
+                "max_tokens": self.max_tokens,
                 "response_format": {
                     "type": "json_schema",
                     "json_schema": {
@@ -270,28 +339,25 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
                 },
             }
 
-            # "top_p",
-            # "top_k",
-            # "frequency_penalty",
-            # "presence_penalty",
-            # # "stop",
-            # # "skip_special_tokens",
-            # "enable_thinking",
+            self.logger.debug(f"Payload:\n{str(payload)[:200]}...")
 
             headers = {"Content-Type": "application/json"}
-
-            response = "Not Processed"
+            if os.getenv("AI_TOKEN"):
+                headers["Authorization"] = f"Bearer {os.getenv('AI_TOKEN')}"
+            response = None
 
             try:
                 response = requests.post(
-                    f"{model_url}/v1/completions", headers=headers, json=payload
+                    f"{model_url}/v1/chat/completions",
+                    headers=headers,
+                    json=payload,
+                    timeout=600,
                 )
                 response.raise_for_status()
-                # logger.log(LEVEL_TRACE, f"Response {response.status_code}\n{response.text}")
 
                 data = response.json()
                 if "error" in data:
-                    raise RuntimeError("LLM error")
+                    raise RuntimeError(f"LLM error: {data['error']}")
 
                 choices = data.get("choices", [])
                 if not choices:
@@ -308,28 +374,105 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
 
                 result = self._parse_response(response_text, chunk, model_name)
                 if result:
+                    self._save_result(chunk, result, model_name)
                     success += 1
                 else:
                     raise RuntimeError("Could not parse result")
 
             except Exception as e:
                 self.logger.error(
-                    f"Error processing chunk {chunk.chunk_id}: \nResponse was:\n{response}\n{e.with_traceback}"
+                    f"Error processing chunk {chunk.chunk_id} with {model_name}: {str(e)}\n{traceback.format_exc()}"
                 )
+                if response is not None:
+                    self.logger.error(f"Response status: {response.status_code}")
+                    self.logger.error(f"Response text: {response.text[:500]}...")
+
                 self._save_result(chunk, [], model_name)
                 errors += 1
+
         return success, errors
 
+    def _run_model_inference(
+        self,
+        chunks: List[Chunk],
+        model_url: str,
+        model_name: str,
+    ) -> tuple[int, int]:
+        """Run inference on a single model with parallel processing"""
+        system_prompt = self._create_system_prompt()
+        response_format = self._create_response_format()
+
+        total_success = 0
+        total_errors = 0
+
+        # Split chunks into batches of batch_size
+        chunk_batches = [
+            chunks[i : i + self.batch_size]
+            for i in range(0, len(chunks), self.batch_size)
+        ]
+
+        self.logger.info(
+            f"Processing {len(chunks)} chunks in {len(chunk_batches)} batches of up to {self.batch_size} chunks each for {model_name}"
+        )
+
+        # Process batches in parallel with max_workers threads
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=self.max_workers
+        ) as executor:
+            # Submit all batch processing tasks
+            future_to_batch = {}
+            for i, batch in enumerate(chunk_batches):
+                future = executor.submit(
+                    self._process_chunk_batch,
+                    batch,
+                    model_url,
+                    model_name,
+                    system_prompt,
+                    response_format,
+                )
+                future_to_batch[future] = i
+
+            # Process completed batches with progress bar
+            with tqdm(total=len(chunk_batches), desc=f"{model_name} batches") as pbar:
+                for future in concurrent.futures.as_completed(future_to_batch):
+                    batch_idx = future_to_batch[future]
+                    try:
+                        success, errors = future.result()
+                        total_success += success
+                        total_errors += errors
+                        pbar.set_postfix(
+                            {
+                                "success": total_success,
+                                "errors": total_errors,
+                                "batch": f"{batch_idx + 1}/{len(chunk_batches)}",
+                            }
+                        )
+                    except Exception as e:
+                        self.logger.error(
+                            f"Batch {batch_idx} failed: {str(e)}\n{traceback.format_exc()}"
+                        )
+                        total_errors += len(chunk_batches[batch_idx])
+                    finally:
+                        pbar.update(1)
+
+        return total_success, total_errors
+
     def _parse_response(
-        self, response_text, chunk: Chunk, model_name: str
+        self, response_text: str, chunk: Chunk, model_name: str
     ) -> list[dict]:
         """Parse model response"""
-        parsed_list = {}
         try:
             parsed = loads(response_text)
-            parsed_list = cast(List[Dict], parsed)
+            # Handle both list and dict responses
+            if isinstance(parsed, dict):
+                parsed_list = [parsed]
+            else:
+                parsed_list = cast(List[Dict], parsed)
         except Exception as e:
-            self.logger.error(f"Errror parsing response for chunk {chunk.chunk_id}")
+            self.logger.error(
+                f"Error parsing response for chunk {chunk.chunk_id}: {str(e)}"
+            )
+            return []
 
         if not parsed_list:
             return []
@@ -348,24 +491,73 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
                     }
                 )
             except Exception as e:
-                self.logger.error(
-                    f"Error parsing response line: \n{e.with_traceback}\n{result}"
-                )
+                self.logger.error(f"Error parsing response line: {str(e)}\n{result}")
         return responses
 
+    def __del__(self):
+        """Clean up database connection"""
+        if hasattr(self, "db"):
+            self.db.close()
+
 
 if __name__ == "__main__":
     import argparse
 
-    parser = argparse.ArgumentParser(description="Run dual Qwen inference")
+    parser = argparse.ArgumentParser(
+        description="Run dual Qwen inference with parallel processing"
+    )
     parser.add_argument("batch_name")
+    parser.add_argument(
+        "--qwen3-model-name", default="Qwen/Qwen3-30B-A3B-Instruct-2507-FP8"
+    )
     parser.add_argument("--qwen3-url", default="http://localhost:8001")
+    parser.add_argument("--qwen25-model-name", default="Qwen/Qwen2.5-72B-Instruct-AWQ")
     parser.add_argument("--qwen25-url", default="http://localhost:8002")
     parser.add_argument("--output-dir", default="./pipeline_output")
+    parser.add_argument(
+        "--max-workers",
+        type=int,
+        default=4,
+        help="Maximum number of parallel workers per model",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=4,
+        help="Number of chunks in each batch",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.01,
+        help="LLM Temperature setting",
+    )
+    parser.add_argument(
+        "--top-p",
+        type=float,
+        default=0.95,
+        help="LLM Top P setting",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=8192,
+        help="Maximum number of tokens to use for each batch",
+    )
 
     args = parser.parse_args()
 
     runner = InferenceRunner(
-        args.batch_name, args.qwen3_url, args.qwen25_url, args.output_dir
+        batch_name=args.batch_name,
+        qwen3_model_name=args.qwen3_model_name,
+        qwen3_url=args.qwen3_url,
+        qwen25_model_name=args.qwen25_model_name,
+        qwen25_url=args.qwen25_url,
+        output_dir=args.output_dir,
+        max_workers=args.max_workers,
+        batch_size=args.batch_size,
+        temperature=args.temperature,
+        top_p=args.top_p,
+        max_tokens=args.max_tokens,
     )
     runner.run_inference()
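
For orientation, here is a minimal usage sketch (not part of the commit) showing how the updated runner can be driven directly from Python instead of the CLI. The batch name is a placeholder, the model names and URLs are simply the script's defaults, and it assumes the repository root is on PYTHONPATH and that ./pipeline_output already contains chunks.json, preprocessed_messages.csv, system_prompt.txt, and response_format.json. Since this commit comments out the Qwen2.5 submission, only the Qwen3 endpoint has to be reachable.

from pipeline.utils.inference_runner import InferenceRunner

# Placeholder batch name; model names and URLs match the script's defaults.
runner = InferenceRunner(
    batch_name="demo_batch",
    qwen3_model_name="Qwen/Qwen3-30B-A3B-Instruct-2507-FP8",
    qwen3_url="http://localhost:8001",
    qwen25_model_name="Qwen/Qwen2.5-72B-Instruct-AWQ",
    qwen25_url="http://localhost:8002",
    output_dir="./pipeline_output",
    max_workers=4,   # threads per ThreadPoolExecutor
    batch_size=4,    # chunks handled per worker task
    temperature=0.01,
    top_p=0.95,
    max_tokens=8192,
)
runner.run_inference()  # results land in ./pipeline_output/demo_batch_processed.db3

Setting the AI_TOKEN environment variable adds a Bearer token to each request, as handled in the headers block above.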