step5_random_sampling.py

  1. """
  2. Step 5: Random stratified sampling for attorney labeling.
  3. """
  4. import random
  5. from typing import List
  6. import numpy as np
  7. from pipeline.models.base import PipelineStep
  8. from pipeline.common_defs import Chunk
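
# Design note: stratifying by semantic-score quartiles keeps the labeled set
# spread across low-, mid-, and high-scoring chunks, rather than letting a
# plain random draw over-represent whichever score band dominates the corpus.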


class RandomSampler(PipelineStep):
    """Random stratified sampling for attorney labeling."""

    def __init__(self, n_samples: int = 20, seed: int = 42,
                 output_dir: str = './pipeline_output'):
        super().__init__(output_dir)
        self.n_samples = n_samples
        self.seed = seed

    def execute(self, chunks: List[Chunk]) -> List[Chunk]:
        """
        Select random stratified samples.

        Args:
            chunks: List of semantically filtered chunks

        Returns:
            List of sampled chunks
        """
        self.logger.info(f"Selecting {self.n_samples} random samples...")
        self.logger.info(f"Random seed: {self.seed}")
        random.seed(self.seed)

        # Stratify by semantic score quartiles. Compare against None rather
        # than relying on truthiness, so a legitimate score of 0.0 is kept.
        scores = [c.semantic_score_combined for c in chunks
                  if c.semantic_score_combined is not None]
        if not scores:
            self.logger.warning("No semantic scores found, falling back to simple random sampling")
            samples = random.sample(chunks, min(self.n_samples, len(chunks)))
        else:
            quartiles = np.percentile(scores, [25, 50, 75])
            samples = self._stratified_sample(chunks, quartiles)

        self.logger.info(f"Selected {len(samples)} samples")

        # Save samples
        self._save_samples(samples)
        return samples

    def _stratified_sample(self, chunks: List[Chunk],
                           quartiles: np.ndarray) -> List[Chunk]:
        """Perform stratified sampling by score quartiles."""
        samples = []

        # Sample from each quartile. Use open-ended outer bounds so chunks
        # scoring below 0 or at/above 1.0 are not silently dropped.
        bounds = [(float('-inf'), quartiles[0]),
                  (quartiles[0], quartiles[1]),
                  (quartiles[1], quartiles[2]),
                  (quartiles[2], float('inf'))]
        for q_low, q_high in bounds:
            stratum = [
                c for c in chunks
                if c.semantic_score_combined is not None
                and q_low <= c.semantic_score_combined < q_high
            ]
            if stratum:
                n_select = min(self.n_samples // 4, len(stratum))
                samples.extend(random.sample(stratum, n_select))

        # Fill remaining slots if needed (e.g. unscored chunks, small strata).
        if len(samples) < self.n_samples:
            sampled_ids = {c.chunk_id for c in samples}
            remaining = [c for c in chunks if c.chunk_id not in sampled_ids]
            if remaining:
                n_more = min(self.n_samples - len(samples), len(remaining))
                samples.extend(random.sample(remaining, n_more))

        # Shuffle and limit
        random.shuffle(samples)
        return samples[:self.n_samples]
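
    # Illustration (not executed): np.percentile returns the three cut points
    # that define the four strata above, e.g.
    #   >>> np.percentile([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], [25, 50, 75])
    #   array([0.275, 0.45 , 0.625])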

    def _save_samples(self, samples: List[Chunk]):
        """Save sampled chunks."""
        samples_data = [
            {
                'chunk_id': c.chunk_id,
                'start_line': c.start_line,
                'end_line': c.end_line,
                'semantic_score': c.semantic_score_combined,
                'num_messages': len(c.messages)
            }
            for c in samples
        ]
        self.save_results(samples_data, 'random_samples.json')
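
    # Output sketch (assumes PipelineStep.save_results serializes to JSON):
    # random_samples.json would then hold a list of records shaped like
    #   {"chunk_id": ..., "start_line": ..., "end_line": ...,
    #    "semantic_score": ..., "num_messages": ...}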


if __name__ == "__main__":
    # Example usage
    import json

    with open('pipeline_output/semantic_filtered_chunks.json', 'r') as f:
        data = json.load(f)

    # Reconstruct chunks (simplified for example; messages and text are
    # left empty because sampling only needs ids, line ranges, and scores)
    chunks = []
    for item in data['filtered_chunks']:
        chunk = Chunk(
            chunk_id=item['chunk_id'],
            start_line=item['start_line'],
            end_line=item['end_line'],
            messages=[],
            combined_text="",
            timestamp_start="",
            timestamp_end="",
            semantic_score_combined=item['score_combined']
        )
        chunks.append(chunk)

    sampler = RandomSampler(n_samples=20)
    samples = sampler.execute(chunks)
    print(f"Selected {len(samples)} samples")
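
# Minimal reproducibility sketch (assumes the same `chunks` list as above):
# execute() reseeds the global RNG with self.seed, so two samplers built with
# the same seed and input select identical chunks.
#
#   a = RandomSampler(n_samples=10, seed=7).execute(chunks)
#   b = RandomSampler(n_samples=10, seed=7).execute(chunks)
#   assert [c.chunk_id for c in a] == [c.chunk_id for c in b]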