step7_inference_prep.py

  1. """
  2. Step 7: Prepare data for dual Qwen inference.
  3. """
  4. from typing import List, Optional
  5. from pathlib import Path
  6. import json
  7. from pipeline.models.base import PipelineStep
  8. from pipeline.common_defs import Chunk, CASE_NAME, SUBPOENA_CRITERIA


class InferencePreparation(PipelineStep):
    """Prepare inference requests for Qwen models."""

    def __init__(self, few_shot_file: Optional[str] = None,
                 output_dir: str = './pipeline_output'):
        super().__init__(output_dir)
        self.few_shot_file = few_shot_file
    def execute(self, chunks: List[Chunk]) -> str:
        """
        Prepare inference requests for dual Qwen models.

        Args:
            chunks: List of filtered chunks

        Returns:
            Path to inference requests file
        """
        self.logger.info("Preparing data for dual Qwen inference...")
        self.logger.info("  Primary: Qwen 3 235B (state-of-the-art)")
        self.logger.info("  Secondary: Qwen 2.5 72B (proven accuracy)")

        # Load few-shot examples if provided
        few_shot_prompt = self._load_few_shot_examples()

        # Create system prompt
        system_prompt = self._create_system_prompt()

        # Create one inference request per chunk
        requests = []
        for chunk in chunks:
            request = self._create_request(chunk, system_prompt, few_shot_prompt)
            requests.append(request)

        # Save requests to JSONL
        filepath = self._save_requests(requests)
        self.logger.info(f"Created {len(requests):,} inference requests")
        self.logger.info(f"Saved to: {filepath}")
        return str(filepath)
    def _load_few_shot_examples(self) -> str:
        """Load few-shot examples from attorney labels."""
        if not self.few_shot_file:
            return ""

        filepath = Path(self.few_shot_file)
        if not filepath.exists():
            self.logger.warning(f"Few-shot file not found: {filepath}")
            return ""

        self.logger.info(f"Loading few-shot examples from: {filepath}")

        # Parse attorney labels and create examples
        # (Simplified - an actual parser for the completed template is still needed)
        few_shot = "\n\nHere are examples of how to classify messages:\n"
        few_shot += "[Attorney-labeled examples would be inserted here]\n"
        return few_shot
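    # Illustrative sketch only (assumed format, not the pipeline's actual
    # template): if the completed attorney labels were JSONL records like
    # {"line": 42, "text": "...", "label": "responsive"}, parsing could be:
    #
    #   for rec in map(json.loads, filepath.read_text().splitlines()):
    #       few_shot += f"Line {rec['line']}: {rec['text']} -> {rec['label']}\n"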
    def _create_system_prompt(self) -> str:
        """Create the system prompt for the LLM."""
        # SUBPOENA_CRITERIA maps criterion number -> description
        criteria_text = "\n".join(
            f"{num}. {desc}"
            for num, desc in SUBPOENA_CRITERIA.items()
        )

        prompt = f"""You are a legal document review specialist analyzing Signal chat messages for a discrimination lawsuit.

CASE: {CASE_NAME}
CLAIM: Discrimination based on gender identity

SUBPOENA CRITERIA - Messages are responsive if they relate to:
{criteria_text}

IMPORTANT: Err on the side of OVER-INCLUSION (high recall)."""
        return prompt
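    # With an assumed SUBPOENA_CRITERIA of, say,
    # {1: "Allegations of gender identity discrimination", 2: "Termination decisions"},
    # criteria_text renders as:
    #   1. Allegations of gender identity discrimination
    #   2. Termination decisions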
    def _create_request(self, chunk: Chunk, system_prompt: str,
                        few_shot_prompt: str) -> dict:
        """Create an inference request for a chunk."""
        # Format the chunk's messages, one per line, tagged with line number and sender
        messages_text = ""
        for msg in chunk.messages:
            messages_text += f"Line {msg.line_number} [{msg.sender}]: {msg.message}\n"

        # Assemble the full prompt
        prompt = f"""{system_prompt}
{few_shot_prompt}
MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}):
{messages_text}
Respond with JSON:
{{
    "responsive_line_numbers": [list of responsive line numbers],
    "reasoning": "brief explanation",
    "confidence": "high/medium/low"
}}"""

        return {
            'chunk_id': chunk.chunk_id,
            'start_line': chunk.start_line,
            'end_line': chunk.end_line,
            'prompt': prompt,
            'num_messages': len(chunk.messages)
        }
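    # A serialized request (field values illustrative) looks roughly like:
    #   {"chunk_id": 12, "start_line": 301, "end_line": 325,
    #    "prompt": "You are a legal document review specialist...",
    #    "num_messages": 25}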
    def _save_requests(self, requests: List[dict]) -> Path:
        """Save inference requests to JSONL (one JSON object per line)."""
        filepath = self.output_dir / 'dual_qwen_inference_requests.jsonl'
        with open(filepath, 'w') as f:
            for req in requests:
                f.write(json.dumps(req) + '\n')
        return filepath
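
# Downstream consumption sketch (an assumption, not part of this step): if the
# Qwen models are served behind an OpenAI-compatible endpoint, each JSONL record
# can be replayed roughly like this (endpoint URL and model id are placeholders):
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
#   with open('pipeline_output/dual_qwen_inference_requests.jsonl') as f:
#       for line in f:
#           req = json.loads(line)
#           resp = client.chat.completions.create(
#               model="qwen3-235b",  # placeholder model id
#               messages=[{"role": "user", "content": req["prompt"]}],
#           )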


if __name__ == "__main__":
    # Example usage
    from pipeline.common_defs import Message  # Chunk and json already imported above

    with open('pipeline_output/semantic_filtered_chunks.json', 'r') as f:
        data = json.load(f)

    # Reconstruct chunks (simplified: placeholder messages; first 10 chunks for testing)
    chunks = []
    for item in data['filtered_chunks'][:10]:
        chunk = Chunk(
            chunk_id=item['chunk_id'],
            start_line=item['start_line'],
            end_line=item['end_line'],
            messages=[Message(item['start_line'], "", "Sender", "Sample", "")],
            combined_text="",
            timestamp_start="",
            timestamp_end=""
        )
        chunks.append(chunk)

    prep = InferencePreparation()
    requests_file = prep.execute(chunks)
    print(f"Requests file: {requests_file}")