@@ -1,37 +1,55 @@
 """
-Inference runner for dual Qwen models.
+Improved inference runner for dual Qwen models with parallel processing.
 """
 
-from collections import defaultdict
 import json
+import os
+import traceback
 import pandas as pd
 import requests
-from typing import List, Dict, TypedDict, cast
+from typing import List, Dict, cast
 from pathlib import Path
 import logging
 from tqdm import tqdm
 import sqlite3
-from json_repair import repair_json, loads
+from json_repair import loads
+import concurrent.futures
+import threading
 
 from pipeline.common_defs import Chunk, Message
 
+
 class InferenceRunner:
-    """Run inference on dual Qwen models"""
+    """Run inference on dual Qwen models with parallel processing"""
 
     def __init__(
         self,
         batch_name: str,
-        qwen3_url: str = "http://localhost:8000",
-        qwen25_url: str = "http://localhost:8001",
+        qwen3_model_name: str = "",
+        qwen3_url: str = "http://localhost:8001",
+        qwen25_model_name: str = "",
+        qwen25_url: str = "http://localhost:8002",
         output_dir: str = "./pipeline_output",
+        max_workers: int = 4,  # New parameter for parallel processing
+        batch_size: int = 4,
+        temperature: float = 0.01,
+        top_p: float = 0.95,
+        max_tokens: int = 8192,
     ):
         self.batch_name = batch_name
+        self.qwen3_model_name = qwen3_model_name
         self.qwen3_url = qwen3_url
+        self.qwen25_model_name = qwen25_model_name
         self.qwen25_url = qwen25_url
         self.output_dir = Path(output_dir)
+        self.max_workers = max_workers
+        self.batch_size = batch_size
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.top_p = top_p
 
         self.logger = logging.getLogger("InferenceRunner")
-        self.logger.setLevel(logging.INFO)
+        self.logger.setLevel(logging.DEBUG)
 
         if not self.logger.handlers:
             handler = logging.StreamHandler()
@@ -41,18 +59,24 @@ class InferenceRunner:
             handler.setFormatter(formatter)
             self.logger.addHandler(handler)
 
+        # Create output directory if it doesn't exist
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
         self.db = sqlite3.connect(
             self.output_dir / f"{batch_name}_processed.db3",
             check_same_thread=False,
         )
+        self.db.row_factory = sqlite3.Row
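+        # sqlite3.Row lets rows be read by column name later on (e.g. row["num_messages"])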
         self.cursor = self.db.cursor()
+
+        # Results table: one row per (model_name, message_index)
         sql = """
 CREATE TABLE IF NOT EXISTS processed (
     chunk_id INTEGER,
-    model_name: TEXT,
+    model_name TEXT,
     message_index INTEGER,
     timestamp DATETIME,
-    sender: TEXT,
+    sender TEXT,
     message TEXT,
     responsive BOOLEAN,
     reason TEXT,
@@ -62,17 +86,18 @@ CREATE TABLE IF NOT EXISTS processed (
 );
 """
         self.cursor.execute(sql)
-        self.db.row_factory = sqlite3.Row
-        self.logger.info("summary database initialized")
+        self.db.commit()  # Persist the schema before any worker threads start
+        self.logger.info("Summary database initialized")
+
+        # Thread lock for database operations
+        self.db_lock = threading.Lock()
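+        # The one connection above is shared by every worker thread
+        # (check_same_thread=False), so all reads and writes take this lock first.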
 
     def _create_user_prompt(self, chunk: Chunk) -> str:
         """Create inference request for a chunk"""
         # Format messages
         messages_text = ""
         for msg in chunk.messages:
-            messages_text += (
-                f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message}\n"
-            )
+            messages_text += f"#{msg.line_number} [{msg.timestamp}] [{msg.sender}]: {msg.message_normalized}\n"
 
         # Create full prompt
         prompt = f"""
@@ -87,14 +112,19 @@ Provide your response as valid JSON following the specified format.
         return prompt
 
     def _create_chunks(self) -> list[Chunk]:
-        with open("pipeline_output/chunks.json", "r") as f:
+        """Create chunks from preprocessed data"""
+        chunks_file = self.output_dir / "chunks.json"
+
+        with open(chunks_file, "r") as f:
             chunk_data = json.load(f)
 
-        msg_df = pd.read_csv(self.output_dir / "preprocessed_messages.csv")
+        msg_file = self.output_dir / "preprocessed_messages.csv"
 
-        # Reconstruct chunks (simplified)
+        msg_df = pd.read_csv(msg_file)
+
+        # Reconstruct chunks
         chunks = []
-        for item in chunk_data["filtered_chunks"][:10]:  # First 10 for testing
+        for item in chunk_data[:2]:  # First 2 chunks for testing
             chunk = Chunk(
                 chunk_id=item["chunk_id"],
                 start_line=item["start_line"],
@@ -102,60 +132,90 @@ Provide your response as valid JSON following the specified format.
                 messages=[],
                 combined_text="",
                 timestamp_start=item["timestamp_start"],
-                timestamp_end=item["timetamp_end"],
+                timestamp_end=item.get("timestamp_end", ""),
             )
+
             chunk_messages = []
-            dfRange = msg_df.iloc[item["start_line"] - 1 : item["end_line"] - 1]
-            for index, row in dfRange.itertuples():
+            dfRange = msg_df.iloc[item["start_line"] - 1 : item["end_line"]]
+            # Columns: index,timestamp,sender,message,line_number,message_normalized
+            for row in dfRange.itertuples():
                 message = Message(
-                    (index + 1),
-                    row["timestamp"],
-                    row["sender"],
-                    row["message_normalized"],
+                    line_number=row.line_number,
+                    timestamp=row.timestamp,
+                    sender=row.sender,
+                    message=row.message,
+                    message_normalized=row.message_normalized,
                 )
+                chunk_messages.append(message)
 
+            chunk.messages = chunk_messages
             chunks.append(chunk)
         return chunks
 
-    def run_inference(self, temperature: float = 0.1, max_tokens: int = 2048):
-        """Run inference on both models"""
+    def run_inference(self):
+        """Run inference on both models with parallel processing"""
         self.logger.info("=" * 80)
-        self.logger.info("RUNNING DUAL QWEN INFERENCE")
+        self.logger.info("RUNNING DUAL QWEN INFERENCE WITH PARALLEL PROCESSING")
         self.logger.info("=" * 80)
 
         chunks = self._create_chunks()
+        if not chunks:
+            self.logger.error("No chunks found. Exiting.")
+            return
+
+        # Run both models in parallel
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=self.max_workers
+        ) as executor:
+            # Submit both model inference tasks
+            qwen3_future = executor.submit(
+                self._run_model_inference,
+                chunks,
+                self.qwen3_url,
+                self.qwen3_model_name,
+            )
 
- self.logger.info("\nRunning Qwen 3 235B inference...")
|
|
|
- self._run_model_inference(
|
|
|
- chunks, self.qwen3_url, "Qwen3-235B", temperature, max_tokens
|
|
|
- )
|
|
|
+ # qwen25_future = executor.submit(
|
|
|
+ # self._run_model_inference,
|
|
|
+ # chunks,
|
|
|
+ # self.qwen25_url,
|
|
|
+ # self.qwen25_model_name,
|
|
|
+ # )
|
|
|
|
|
|
- self.logger.info("\nRunning Qwen 2.5 72B inference...")
|
|
|
- self._run_model_inference(
|
|
|
- chunks, self.qwen25_url, "Qwen2.5-72B", temperature, max_tokens
|
|
|
- )
|
|
|
+ # Wait for both to complete
|
|
|
+ qwen3_success, qwen3_errors = qwen3_future.result()
|
|
|
+ # qwen25_success, qwen25_errors = qwen25_future.result()
|
|
|
|
|
|
+ self.logger.info(
|
|
|
+ f"\nQwen3-235B Results: {qwen3_success} success, {qwen3_errors} errors"
|
|
|
+ )
|
|
|
+ # self.logger.info(
|
|
|
+ # f"Qwen2.5-72B Results: {qwen25_success} success, {qwen25_errors} errors"
|
|
|
+ # )
|
|
|
self.logger.info("\n" + "=" * 80)
|
|
|
self.logger.info("INFERENCE COMPLETE")
|
|
|
self.logger.info("=" * 80)
|
|
|
|
|
|
     def _create_system_prompt(self) -> str:
         """Create system prompt for LLM"""
-        prompt = ""
-        with Path(self.output_dir, "system_prompt.txt").open("r") as file:
+        prompt_file = Path(self.output_dir, "system_prompt.txt")
+
+        with prompt_file.open("r") as file:
             prompt = file.read()
         return prompt
 
     def _create_response_format(self) -> str:
         """Create response format for LLM"""
-        response_format = ""
-        with Path(self.output_dir, "response_format.json").open() as file:
+        format_file = Path(self.output_dir, "response_format.json")
+
+        with format_file.open("r") as file:
             response_format = file.read()
         return response_format
 
-    def _check_existing_result(self, chunk: Chunk, model_name) -> bool:
+    def _check_existing_result(self, chunk: Chunk, model_name: str) -> bool:
         """Check if result already exists in db"""
-        sql = """
+        with self.db_lock:
+            sql = """
 SELECT
     COUNT(*) AS num_messages
 FROM
@@ -164,21 +224,22 @@ WHERE
     model_name = ?
     AND chunk_id = ?
     AND responsive IS NOT NULL
-        """
-        result = self.cursor.execute(sql, (model_name, chunk.chunk_id))
-        row: dict = self.cursor.fetchone()
-        if row and row["0"] == len(chunk.messages):
-            return True
-        return False
+            """
+            result = self.cursor.execute(sql, (model_name, chunk.chunk_id))
+            row = self.cursor.fetchone()
+            if row and row["num_messages"] == len(chunk.messages):
+                return True
+            return False
 
     def _save_result(self, chunk: Chunk, results: list[dict], model_name: str):
         """Save result to db"""
-        # merge the chunk messages with the results
+        # Merge the chunk messages with the results
        merged_results = {}
         for msg in chunk.messages:
             merged_results[msg.line_number] = {"message": msg}
+
         for item in results:
-            if item["message_index"] in merged_results:
+            if item["message_index"] in merged_results.keys():
                 merged_results[item["message_index"]].update(item)
             else:
                 merged_results[item["message_index"]] = item.copy()
@@ -206,48 +267,55 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
     criteria = excluded.criteria,
     confidence = excluded.confidence
 """
-        for result in merged_results:
-            msg = result.get("message", None)
 
-            if msg is None or not isinstance(msg, Message):
-                self.logger.error(
-                    f"somehow we have a result without a message: \n{result}"
+        with self.db_lock:
+            for msg_idx, result_dict in merged_results.items():
+                msg = result_dict.get("message", None)
+
+                if msg is None:
+                    self.logger.error(
+                        f"Result without a message for line {msg_idx}: \n{result_dict}"
+                    )
+                    continue
+                elif not isinstance(msg, Message):
+                    self.logger.error(
+                        f"'message' for line {msg_idx} is not a Message instance: \n{result_dict}"
+                    )
+                    continue
+
+                self.cursor.execute(
+                    sql,
+                    (
+                        result_dict.get("chunk_id", chunk.chunk_id),
+                        model_name,
+                        msg.line_number,
+                        msg.timestamp,
+                        msg.sender,
+                        msg.message,
+                        result_dict.get("responsive", None),
+                        result_dict.get("reason", None),
+                        str(result_dict.get("criteria", [])),  # Convert to string
+                        result_dict.get("confidence", None),
+                    ),
                 )
-            continue
-
-            self.cursor.execute(
-                sql,
-                (
-                    result.get("chunk_id", None),
-                    model_name,
-                    msg.line_number,
-                    msg.timestamp,
-                    msg.sender,
-                    msg.message,
-                    result.get("responsive", None),
-                    result.get("reason", None),
-                    result.get("criteria", None),
-                ),
-            )
+            self.db.commit()
 
-    def _run_model_inference(
+    def _process_chunk_batch(
         self,
-        chunks: List[Chunk],
+        chunk_batch: List[Chunk],
         model_url: str,
         model_name: str,
-        temperature: float,
-        max_tokens: int,
-    ):
-        """Run inference on a single model"""
-        system_prompt = self._create_system_prompt()
-        response_format = self._create_response_format()
-
+        system_prompt: str,
+        response_format: str,
+    ) -> tuple[int, int]:
+        """Process a batch of chunks"""
         success = 0
         errors = 0
 
-        for chunk in tqdm(chunks, desc=f"{model_name} inference"):
-            # check if this chunk has already been processed
+        for chunk in chunk_batch:
+            # Check if this chunk has already been processed
             if self._check_existing_result(chunk, model_name):
+                success += 1
                 continue
 
             prompt_messages = []
@@ -259,8 +327,9 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
             payload = {
                 "model": model_name,
                 "messages": prompt_messages,
-                "temperature": temperature,
-                "max_tokens": max_tokens,
+                "temperature": self.temperature,
+                "top_p": self.top_p,
+                "max_tokens": self.max_tokens,
                 "response_format": {
                     "type": "json_schema",
                     "json_schema": {
@@ -270,28 +339,25 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
                 },
             }
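+            # The json_schema response_format above assumes an OpenAI-compatible server
+            # with structured-output support (e.g. vLLM); other backends may ignore it.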
 
-            # "top_p",
-            # "top_k",
-            # "frequency_penalty",
-            # "presence_penalty",
-            # # "stop",
-            # # "skip_special_tokens",
-            # "enable_thinking",
+            self.logger.debug(f"Payload:\n{str(payload)[:200]}...")
 
             headers = {"Content-Type": "application/json"}
-
-            response = "Not Processed"
+            if os.getenv("AI_TOKEN"):
+                headers["Authorization"] = f"Bearer {os.getenv('AI_TOKEN')}"
+            response = None
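+            # None (rather than a placeholder string) lets the except block tell whether
+            # an HTTP response was actually received before logging its status and body.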
 
             try:
                 response = requests.post(
-                    f"{model_url}/v1/completions", headers=headers, json=payload
+                    f"{model_url}/v1/chat/completions",
+                    headers=headers,
+                    json=payload,
+                    timeout=600,
                 )
                 response.raise_for_status()
-                # logger.log(LEVEL_TRACE, f"Response {response.status_code}\n{response.text}")
 
                 data = response.json()
                 if "error" in data:
-                    raise RuntimeError("LLM error")
+                    raise RuntimeError(f"LLM error: {data['error']}")
 
                 choices = data.get("choices", [])
                 if not choices:
@@ -308,28 +374,105 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
 
                 result = self._parse_response(response_text, chunk, model_name)
                 if result:
+                    self._save_result(chunk, result, model_name)
                     success += 1
                 else:
                     raise RuntimeError("Could not parse result")
 
             except Exception as e:
                 self.logger.error(
-                    f"Error processing chunk {chunk.chunk_id}: \nResponse was:\n{response}\n{e.with_traceback}"
+                    f"Error processing chunk {chunk.chunk_id} with {model_name}: {str(e)}\n{traceback.format_exc()}"
                 )
+                if response is not None:
+                    self.logger.error(f"Response status: {response.status_code}")
+                    self.logger.error(f"Response text: {response.text[:500]}...")
+
                 self._save_result(chunk, [], model_name)
                 errors += 1
+
         return success, errors
 
+    def _run_model_inference(
+        self,
+        chunks: List[Chunk],
+        model_url: str,
+        model_name: str,
+    ) -> tuple[int, int]:
+        """Run inference on a single model with parallel processing"""
+        system_prompt = self._create_system_prompt()
+        response_format = self._create_response_format()
+
+        total_success = 0
+        total_errors = 0
+
+        # Split chunks into batches of batch_size
+        chunk_batches = [
+            chunks[i : i + self.batch_size]
+            for i in range(0, len(chunks), self.batch_size)
+        ]
+
+        self.logger.info(
+            f"Processing {len(chunks)} chunks in {len(chunk_batches)} batches of up to {self.batch_size} chunks each for {model_name}"
+        )
+
+        # Process batches in parallel with max_workers threads
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=self.max_workers
+        ) as executor:
+            # Submit all batch processing tasks
+            future_to_batch = {}
+            for i, batch in enumerate(chunk_batches):
+                future = executor.submit(
+                    self._process_chunk_batch,
+                    batch,
+                    model_url,
+                    model_name,
+                    system_prompt,
+                    response_format,
+                )
+                future_to_batch[future] = i
+
+            # Process completed batches with progress bar
+            with tqdm(total=len(chunk_batches), desc=f"{model_name} batches") as pbar:
+                for future in concurrent.futures.as_completed(future_to_batch):
+                    batch_idx = future_to_batch[future]
+                    try:
+                        success, errors = future.result()
+                        total_success += success
+                        total_errors += errors
+                        pbar.set_postfix(
+                            {
+                                "success": total_success,
+                                "errors": total_errors,
+                                "batch": f"{batch_idx + 1}/{len(chunk_batches)}",
+                            }
+                        )
+                    except Exception as e:
+                        self.logger.error(
+                            f"Batch {batch_idx} failed: {str(e)}\n{traceback.format_exc()}"
+                        )
+                        total_errors += len(chunk_batches[batch_idx])
+                    finally:
+                        pbar.update(1)
+
+        return total_success, total_errors
+
     def _parse_response(
-        self, response_text, chunk: Chunk, model_name: str
+        self, response_text: str, chunk: Chunk, model_name: str
     ) -> list[dict]:
         """Parse model response"""
-        parsed_list = {}
         try:
             parsed = loads(response_text)
-            parsed_list = cast(List[Dict], parsed)
+            # Handle both list and dict responses
+            if isinstance(parsed, dict):
+                parsed_list = [parsed]
+            else:
+                parsed_list = cast(List[Dict], parsed)
         except Exception as e:
-            self.logger.error(f"Errror parsing response for chunk {chunk.chunk_id}")
+            self.logger.error(
+                f"Error parsing response for chunk {chunk.chunk_id}: {str(e)}"
+            )
+            return []
 
         if not parsed_list:
             return []
@@ -348,24 +491,73 @@ ON CONFLICT (model_name, message_index) DO UPDATE SET
                     }
                 )
             except Exception as e:
-                self.logger.error(
-                    f"Error parsing response line: \n{e.with_traceback}\n{result}"
-                )
+                self.logger.error(f"Error parsing response line: {str(e)}\n{result}")
         return responses
 
+    def __del__(self):
+        """Clean up database connection"""
+        if hasattr(self, "db"):
+            self.db.close()
+
 
 if __name__ == "__main__":
     import argparse
 
-    parser = argparse.ArgumentParser(description="Run dual Qwen inference")
+    parser = argparse.ArgumentParser(
+        description="Run dual Qwen inference with parallel processing"
+    )
     parser.add_argument("batch_name")
+    parser.add_argument(
+        "--qwen3-model-name", default="Qwen/Qwen3-30B-A3B-Instruct-2507-FP8"
+    )
     parser.add_argument("--qwen3-url", default="http://localhost:8001")
+    parser.add_argument("--qwen25-model-name", default="Qwen/Qwen2.5-72B-Instruct-AWQ")
     parser.add_argument("--qwen25-url", default="http://localhost:8002")
     parser.add_argument("--output-dir", default="./pipeline_output")
+    parser.add_argument(
+        "--max-workers",
+        type=int,
+        default=4,
+        help="Maximum number of parallel workers per model",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=4,
+        help="Number of chunks in each batch",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.01,
+        help="LLM temperature setting",
+    )
+    parser.add_argument(
+        "--top-p",
+        type=float,
+        default=0.95,
+        help="LLM top-p setting",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=8192,
+        help="Maximum number of tokens the model may generate per request",
+    )
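+    # Example invocation (script name is illustrative):
+    #   python inference_runner.py my_batch --max-workers 8 --batch-size 4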
 
     args = parser.parse_args()
 
     runner = InferenceRunner(
-        args.batch_name, args.qwen3_url, args.qwen25_url, args.output_dir
+        batch_name=args.batch_name,
+        qwen3_model_name=args.qwen3_model_name,
+        qwen3_url=args.qwen3_url,
+        qwen25_model_name=args.qwen25_model_name,
+        qwen25_url=args.qwen25_url,
+        output_dir=args.output_dir,
+        max_workers=args.max_workers,
+        batch_size=args.batch_size,
+        temperature=args.temperature,
+        top_p=args.top_p,
+        max_tokens=args.max_tokens,
     )
     runner.run_inference()