step7_inference_prep.py

  1. """
  2. Step 7: Prepare data for dual Qwen inference.
  3. """
  4. from typing import List, Optional
  5. from pathlib import Path
  6. import json
  7. from pipeline.models.base import PipelineStep
  8. from pipeline.common_defs import Chunk, CASE_NAME, SUBPOENA_CRITERIA


class InferencePreparation(PipelineStep):
    """Prepare inference requests for the dual Qwen models."""

    def __init__(self, few_shot_file: Optional[str] = None,
                 output_dir: str = './pipeline_output'):
        super().__init__(output_dir)
        # Optional path to a few-shot examples file
        self.few_shot_file = few_shot_file

    def execute(self, chunks: List[Chunk]) -> str:
        """
        Prepare inference requests for the dual Qwen models.

        Args:
            chunks: List of filtered chunks.

        Returns:
            Path to the inference requests file.
        """
        self.logger.info("Preparing data for dual Qwen inference...")
        self.logger.info("  Primary: Qwen 3 235B (state-of-the-art)")
        self.logger.info("  Secondary: Qwen 2.5 72B (proven accuracy)")

        # Create one inference request per chunk
        requests = []
        for chunk in chunks:
            request = self._create_request(chunk)
            requests.append(request)

        # Save requests to disk as JSONL
        filepath = self._save_requests(requests)
        self.logger.info(f"Created {len(requests):,} inference requests")
        self.logger.info(f"Saved to: {filepath}")
        return str(filepath)

    def _create_request(self, chunk: Chunk) -> dict:
        """Create an inference request for a single chunk."""
        # Format messages one per line, e.g. (illustrative):
        #   Line 42 [Alice]: <message text>
        messages_text = ""
        for msg in chunk.messages:
            messages_text += f"Line {msg.line_number} [{msg.sender}]: {msg.message}\n"

        # Assemble the full prompt
        prompt = f"""
Review and classify the following messages.
MESSAGES TO REVIEW (Lines {chunk.start_line}-{chunk.end_line}):
{messages_text}
Provide your response as valid JSON following the specified format.
"""
        return {
            "chunk_id": chunk.chunk_id,
            "start_line": chunk.start_line,
            "end_line": chunk.end_line,
            "prompt": prompt,
            "num_messages": len(chunk.messages),
        }

    def _save_requests(self, requests: List[dict]) -> Path:
        """Save inference requests to a JSONL file, one JSON object per line."""
        filepath = self.output_dir / "inference_requests.jsonl"
        with open(filepath, 'w') as f:
            for req in requests:
                f.write(json.dumps(req) + '\n')
        return filepath
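

# Hypothetical helper, included only as a sketch of how a downstream step might
# read the file written by _save_requests. It assumes each line of
# inference_requests.jsonl holds one JSON object with the fields built in
# _create_request (chunk_id, start_line, end_line, prompt, num_messages);
# load_requests itself is not part of the original pipeline.
def load_requests(filepath: str) -> List[dict]:
    """Read inference requests back from a JSONL file, one dict per line."""
    with open(filepath, 'r') as f:
        return [json.loads(line) for line in f]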


if __name__ == "__main__":
    # Example usage: rebuild a few chunks from the previous step's output.
    # json and Chunk are already imported at module level; only Message is needed here.
    from pipeline.common_defs import Message

    with open('pipeline_output/semantic_filtered_chunks.json', 'r') as f:
        data = json.load(f)

    # Reconstruct chunks (simplified: placeholder message, empty text and timestamps)
    chunks = []
    for item in data['filtered_chunks'][:10]:  # First 10 for testing
        chunk = Chunk(
            chunk_id=item['chunk_id'],
            start_line=item['start_line'],
            end_line=item['end_line'],
            messages=[Message(item['start_line'], "", "Sender", "Sample", "")],
            combined_text="",
            timestamp_start="",
            timestamp_end=""
        )
        chunks.append(chunk)

    prep = InferencePreparation()
    requests_file = prep.execute(chunks)
    print(f"Requests file: {requests_file}")
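
# For reference, each line of inference_requests.jsonl is one request dict as
# returned by _create_request; the values below are illustrative, not real output:
# {"chunk_id": 0, "start_line": 1, "end_line": 25, "prompt": "...", "num_messages": 25}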