main_pipeline.py

  1. """
  2. Main pipeline orchestrator - runs all steps in sequence.
  3. """
  4. import sys
  5. from pathlib import Path
  6. from typing import Optional
  7. import logging
  8. # Add pipeline to path
  9. sys.path.insert(0, str(Path(__file__).parent.parent))
  10. from pipeline.steps.step1_load_data import DataLoader
  11. from pipeline.steps.step2_create_chunks import ChunkCreator
  12. from pipeline.steps.step3_keyword_filter import KeywordFilter
  13. from pipeline.steps.step4_semantic_filter import SemanticFilter
  14. from pipeline.steps.step5_random_sampling import RandomSampler
  15. from pipeline.steps.step6_labeling_template import LabelingTemplateGenerator
  16. from pipeline.steps.step7_inference_prep import InferencePreparation
  17. from pipeline.steps.step8_merge_results import ResultsMerger
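
# Pipeline flow as wired below: load messages -> create chunks -> keyword filter
# -> semantic filter -> random-sample chunks for attorney labeling -> generate
# labeling template -> prepare dual-model inference requests -> merge results.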

class DiscoveryPipeline:
    """Main pipeline orchestrator."""

    def __init__(self, csv_path: str, output_dir: str = './pipeline_output'):
        self.csv_path = csv_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Setup logging
        self.logger = self._setup_logger()

        # Initialize steps
        self.data_loader = DataLoader(csv_path, output_dir)
        self.chunk_creator = ChunkCreator(chunk_size=20, overlap=5, output_dir=output_dir)
        self.keyword_filter = KeywordFilter(output_dir)
        self.semantic_filter = SemanticFilter(
            threshold1=0.25,
            threshold2=0.25,
            merge_strategy='union',
            output_dir=output_dir
        )
        self.random_sampler = RandomSampler(n_samples=20, seed=42, output_dir=output_dir)
        self.template_generator = LabelingTemplateGenerator(output_dir)
        self.inference_prep = InferencePreparation(output_dir=output_dir)
        self.results_merger = ResultsMerger(merge_strategy='union', output_dir=output_dir)

    def _setup_logger(self) -> logging.Logger:
        """Setup main pipeline logger."""
        logger = logging.getLogger('DiscoveryPipeline')
        # Let DEBUG records through at the logger; each handler filters its own level
        logger.setLevel(logging.DEBUG)
        if not logger.handlers:
            # Console handler
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)

            # File handler
            file_handler = logging.FileHandler(self.output_dir / 'pipeline.log')
            file_handler.setLevel(logging.DEBUG)

            # Formatter
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            console_handler.setFormatter(formatter)
            file_handler.setFormatter(formatter)

            logger.addHandler(console_handler)
            logger.addHandler(file_handler)
        return logger

    def run_preprocessing(self):
        """Run preprocessing steps (1-7)."""
        self.logger.info("=" * 80)
        self.logger.info("STARTING PREPROCESSING PIPELINE")
        self.logger.info("=" * 80)

        # Step 1: Load data
        self.logger.info("\nStep 1: Loading data...")
        df = self.data_loader.execute()

        # Step 2: Create chunks
        self.logger.info("\nStep 2: Creating chunks...")
        chunks = self.chunk_creator.execute(df)

        # Step 3: Keyword filter
        self.logger.info("\nStep 3: Applying keyword filter...")
        keyword_filtered = self.keyword_filter.execute(chunks)

        # Step 4: Semantic filter
        self.logger.info("\nStep 4: Applying semantic filter...")
        semantic_filtered = self.semantic_filter.execute(keyword_filtered)

        # Step 5: Random sampling
        self.logger.info("\nStep 5: Random sampling...")
        samples = self.random_sampler.execute(semantic_filtered)

        # Step 6: Generate labeling template
        self.logger.info("\nStep 6: Generating labeling template...")
        template_path = self.template_generator.execute(samples)

        # Step 7: Prepare inference requests
        self.logger.info("\nStep 7: Preparing inference requests...")
        requests_path = self.inference_prep.execute(semantic_filtered)

        self.logger.info("\n" + "=" * 80)
        self.logger.info("PREPROCESSING COMPLETE")
        self.logger.info("=" * 80)
        self.logger.info(f"\nTotal messages: {len(df):,}")
        self.logger.info(f"Total chunks: {len(chunks):,}")
        self.logger.info(f"After keyword filter: {len(keyword_filtered):,}")
        self.logger.info(f"After semantic filter: {len(semantic_filtered):,}")
        self.logger.info(f"Samples for attorney: {len(samples)}")
        self.logger.info("\nNext steps:")
        self.logger.info(f"1. Attorney completes labeling template: {template_path}")
        self.logger.info("2. Deploy Qwen 3 235B and Qwen 2.5 72B models")
        self.logger.info(f"3. Run inference using: {requests_path}")
        self.logger.info("4. Run merge_results() with inference outputs")

        return {
            'df': df,
            'chunks': chunks,
            'keyword_filtered': keyword_filtered,
            'semantic_filtered': semantic_filtered,
            'samples': samples,
            'template_path': template_path,
            'requests_path': requests_path
        }

    def merge_results(self, qwen3_results_file: str, qwen25_results_file: str):
        """Merge results from dual-model inference (Step 8)."""
        self.logger.info("=" * 80)
        self.logger.info("MERGING INFERENCE RESULTS")
        self.logger.info("=" * 80)

        merged = self.results_merger.execute(qwen3_results_file, qwen25_results_file)

        self.logger.info("\n" + "=" * 80)
        self.logger.info("MERGE COMPLETE")
        self.logger.info("=" * 80)
        self.logger.info(f"\nMerged {len(merged)} results")
        self.logger.info(f"Results saved to: {self.output_dir / 'merged_results.json'}")
        return merged
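

# Illustrative programmatic usage (a sketch; the CSV and results file names
# below are hypothetical placeholders, not files defined in this script):
#
#   pipeline = DiscoveryPipeline('signal_export.csv', output_dir='./pipeline_output')
#   artifacts = pipeline.run_preprocessing()
#   # ...run dual-model inference externally on artifacts['requests_path']...
#   merged = pipeline.merge_results('qwen3_results.json', 'qwen25_results.json')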


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Legal Discovery Pipeline')
    parser.add_argument('csv_path', help='Path to Signal messages CSV')
    parser.add_argument('--output-dir', default='./pipeline_output',
                        help='Output directory')
    parser.add_argument('--step', choices=['preprocess', 'merge'],
                        default='preprocess',
                        help='Pipeline step to run')
    parser.add_argument('--qwen3-results', help='Qwen 3 results file (for merge)')
    parser.add_argument('--qwen25-results', help='Qwen 2.5 results file (for merge)')
    args = parser.parse_args()

    pipeline = DiscoveryPipeline(args.csv_path, args.output_dir)

    if args.step == 'preprocess':
        results = pipeline.run_preprocessing()
    elif args.step == 'merge':
        if not args.qwen3_results or not args.qwen25_results:
            print("Error: --qwen3-results and --qwen25-results required for merge step")
            sys.exit(1)
        results = pipeline.merge_results(args.qwen3_results, args.qwen25_results)
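
# Example invocations (file paths are illustrative placeholders):
#   python main_pipeline.py signal_messages.csv --output-dir ./pipeline_output
#   python main_pipeline.py signal_messages.csv --step merge \
#       --qwen3-results qwen3_results.json --qwen25-results qwen25_results.json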