- """
- Step 7: Prepare data for dual Qwen inference.
- """
- from typing import List, Optional
- from pathlib import Path
- import json
- from pipeline.models.base import PipelineStep
- from pipeline.common_defs import Chunk, CASE_NAME, SUBPOENA_CRITERIA


class InferencePreparation(PipelineStep):
    """Prepare inference requests for Qwen models."""
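    # PipelineStep is assumed (from its usage below) to provide self.logger
    # and a Path-like self.output_dir via its __init__.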

    def __init__(self, few_shot_file: Optional[str] = None,
                 output_dir: str = './pipeline_output'):
        super().__init__(output_dir)
        self.few_shot_file = few_shot_file

    def execute(self, chunks: List[Chunk]) -> str:
        """
        Prepare inference requests for dual Qwen models.

        Args:
            chunks: List of filtered chunks

        Returns:
            Path to inference requests file
        """
        self.logger.info("Preparing data for dual Qwen inference...")
        self.logger.info("  Primary: Qwen 3 235B (state-of-the-art)")
        self.logger.info("  Secondary: Qwen 2.5 72B (proven accuracy)")

        # Load few-shot examples if provided
        few_shot_prompt = self._load_few_shot_examples()

        # Create system prompt
        system_prompt = self._create_system_prompt()

        # Create one inference request per chunk
        requests = []
        for chunk in chunks:
            request = self._create_request(chunk, system_prompt, few_shot_prompt)
            requests.append(request)

        # Save requests to JSONL
        filepath = self._save_requests(requests)

        self.logger.info(f"Created {len(requests):,} inference requests")
        self.logger.info(f"Saved to: {filepath}")

        return str(filepath)

    def _load_few_shot_examples(self) -> str:
        """Load few-shot examples from attorney labels."""
        if not self.few_shot_file:
            return ""

        filepath = Path(self.few_shot_file)
        if not filepath.exists():
            self.logger.warning(f"Few-shot file not found: {filepath}")
            return ""

        self.logger.info(f"Loading few-shot examples from: {filepath}")

        # Placeholder: a real implementation would parse the completed
        # attorney-label template (a hypothetical sketch follows this method)
        few_shot = "\n\nHere are examples of how to classify messages:\n"
        few_shot += "[Attorney-labeled examples would be inserted here]\n"

        return few_shot
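
    # --- Hypothetical sketch, not part of the pipeline as written ---
    # One way the placeholder above could be filled in, assuming the completed
    # attorney template is a JSONL file whose records look like
    # {"text": "...", "responsive": true, "reasoning": "..."}. The file format
    # and field names here are assumptions for illustration, not the real
    # template.
    def _parse_attorney_labels(self, filepath: Path) -> str:
        """Render attorney-labeled records as a few-shot prompt block."""
        examples = ["\n\nHere are examples of how to classify messages:"]
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                label = json.loads(line)
                verdict = "RESPONSIVE" if label.get('responsive') else "NOT RESPONSIVE"
                examples.append(
                    f"Message: {label.get('text', '')}\n"
                    f"Classification: {verdict}\n"
                    f"Reasoning: {label.get('reasoning', '')}"
                )
        return "\n".join(examples) + "\n"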

    def _create_system_prompt(self) -> str:
        """Create system prompt for LLM."""
        criteria_text = "\n".join([
            f"{num}. {desc}"
            for num, desc in SUBPOENA_CRITERIA.items()
        ])

        prompt = f"""You are a legal document review specialist analyzing Signal chat messages for a discrimination lawsuit.
CASE: {CASE_NAME}
CLAIM: Discrimination based on gender identity
SUBPOENA CRITERIA - Messages are responsive if they relate to:
{criteria_text}
IMPORTANT: Err on the side of OVER-INCLUSION (high recall)."""

        return prompt

    def _create_request(self, chunk: Chunk, system_prompt: str,
                        few_shot_prompt: str) -> dict:
        """Create inference request for a chunk."""
        # Format messages, tagged with chat line number and sender
        messages_text = ""
        for msg in chunk.messages:
            messages_text += f"Line {msg.line_number} [{msg.sender}]: {msg.message}\n"

        # Create full prompt
        prompt = f"""{system_prompt}
{few_shot_prompt}
MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}):
{messages_text}
Respond with JSON:
{{
    "responsive_line_numbers": [list of responsive line numbers],
    "reasoning": "brief explanation",
    "confidence": "high/medium/low"
}}"""

        return {
            'chunk_id': chunk.chunk_id,
            'start_line': chunk.start_line,
            'end_line': chunk.end_line,
            'prompt': prompt,
            'num_messages': len(chunk.messages)
        }

    def _save_requests(self, requests: List[dict]) -> Path:
        """Save inference requests to JSONL."""
        filepath = self.output_dir / 'dual_qwen_inference_requests.jsonl'

        # utf-8 so non-ASCII chat content round-trips on any platform
        with open(filepath, 'w', encoding='utf-8') as f:
            for req in requests:
                f.write(json.dumps(req) + '\n')

        return filepath
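

# Illustrative shape of one line in dual_qwen_inference_requests.jsonl
# (values invented; the real "prompt" is the full system + few-shot + chunk
# text assembled above):
# {"chunk_id": 0, "start_line": 1, "end_line": 25,
#  "prompt": "You are a legal document review specialist...", "num_messages": 25}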

if __name__ == "__main__":
    # Example usage (json and Chunk are already imported at module level)
    from pipeline.common_defs import Message

    with open('pipeline_output/semantic_filtered_chunks.json', 'r') as f:
        data = json.load(f)

    # Reconstruct chunks (simplified: placeholder message fields)
    chunks = []
    for item in data['filtered_chunks'][:10]:  # First 10 for testing
        chunk = Chunk(
            chunk_id=item['chunk_id'],
            start_line=item['start_line'],
            end_line=item['end_line'],
            messages=[Message(item['start_line'], "", "Sender", "Sample", "")],
            combined_text="",
            timestamp_start="",
            timestamp_end=""
        )
        chunks.append(chunk)

    prep = InferencePreparation()
    requests_file = prep.execute(chunks)
    print(f"Requests file: {requests_file}")