step5_random_sampling.py

  1. """
  2. Step 5: Random stratified sampling for attorney labeling.
  3. """
  4. import random
  5. from typing import List
  6. import numpy as np
  7. from pipeline.models.base import PipelineStep
  8. from pipeline.common_defs import Chunk


class RandomSampler(PipelineStep):
    """Random stratified sampling for attorney labeling."""

    def __init__(
        self, n_samples: int = 50, seed: int = 42, output_dir: str = "./pipeline_output"
    ):
        super().__init__(output_dir)
        self.n_samples = n_samples
        self.seed = seed

    def execute(self, chunks: List[Chunk]) -> List[Chunk]:
        """
        Select random stratified samples.

        Args:
            chunks: List of semantically filtered chunks

        Returns:
            List of sampled chunks
        """
        self.logger.info(f"Selecting {self.n_samples} random samples...")
        self.logger.info(f"Random seed: {self.seed}")
        random.seed(self.seed)

        # Stratify into four score bands. The cut points (60th, 80th, and 90th
        # percentiles) are not true quartiles: they slice the high end more
        # finely, so equal per-band quotas oversample high-scoring chunks.
        scores = [
            c.semantic_score_combined
            for c in chunks
            if c.semantic_score_combined is not None
        ]
        if not scores:
            self.logger.warning("No semantic scores found, using random sampling")
            samples = random.sample(chunks, min(self.n_samples, len(chunks)))
        else:
            thresholds = np.percentile(scores, [60, 80, 90])
            samples = self._stratified_sample(chunks, thresholds)

        self.logger.info(f"Selected {len(samples)} samples")

        # Save samples for the labeling step
        self._save_samples(samples)
        return samples
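
    # Note: random.seed(self.seed) reseeds the process-wide generator shared by
    # every user of the random module. If other pipeline steps also draw from
    # random, a private generator (rng = random.Random(self.seed)) gives the
    # same reproducibility without cross-step interference.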

    def _stratified_sample(self, chunks: List[Chunk],
                           thresholds: np.ndarray) -> List[Chunk]:
        """Perform stratified sampling across four score bands."""
        samples = []

        # Sample evenly from each band; the last band is open-ended so the
        # maximum score is not excluded by a half-open upper bound.
        bands = [(float("-inf"), thresholds[0]), (thresholds[0], thresholds[1]),
                 (thresholds[1], thresholds[2]), (thresholds[2], float("inf"))]
        for q_low, q_high in bands:
            stratum = [
                c for c in chunks
                if c.semantic_score_combined is not None
                and q_low <= c.semantic_score_combined < q_high
            ]
            if stratum:
                n_select = min(self.n_samples // 4, len(stratum))
                samples.extend(random.sample(stratum, n_select))

        # Fill remaining slots if any band came up short; compare by identity
        # so Chunk does not need __eq__/__hash__ and membership checks stay cheap
        if len(samples) < self.n_samples:
            sampled_ids = {id(c) for c in samples}
            remaining = [c for c in chunks if id(c) not in sampled_ids]
            if remaining:
                n_more = min(self.n_samples - len(samples), len(remaining))
                samples.extend(random.sample(remaining, n_more))

        # Shuffle so labeling order carries no score signal, then cap the count
        random.shuffle(samples)
        return samples[:self.n_samples]
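
    # With four bands, the per-band quota n_samples // 4 fills at most
    # 4 * (n_samples // 4) slots (48 for the default n_samples = 50), so the
    # top-up pass above supplies the remainder whenever enough chunks exist.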

    def _save_samples(self, samples: List[Chunk]):
        """Save sampled chunks for attorney labeling."""
        samples_data = [
            {
                "chunk_id": c.chunk_id,
                "start_line": c.start_line,
                "end_line": c.end_line,
                "semantic_score": c.semantic_score_combined,
                "num_messages": c.end_line - c.start_line,
            }
            for c in samples
        ]
        self.save_results(samples_data, "random_samples.json")


if __name__ == "__main__":
    # Example usage
    import json

    with open("pipeline_output/semantic_filtered_chunks.json", "r") as f:
        data = json.load(f)

    # Reconstruct chunks (simplified for the example: sampling only needs
    # ids, line ranges, and scores, so messages/text/timestamps stay empty)
    chunks = []
    for item in data["filtered_chunks"]:
        chunk = Chunk(
            chunk_id=item["chunk_id"],
            start_line=item["start_line"],
            end_line=item["end_line"],
            messages=[],
            combined_text="",
            timestamp_start="",
            timestamp_end="",
            semantic_score_combined=item["score_combined"],
        )
        chunks.append(chunk)

    sampler = RandomSampler(n_samples=100)
    samples = sampler.execute(chunks)
    print(f"Selected {len(samples)} samples")
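
# Assuming PipelineStep.save_results writes JSON into output_dir, the sampler
# leaves pipeline_output/random_samples.json with one record per sampled chunk:
#   {"chunk_id": ..., "start_line": ..., "end_line": ...,
#    "semantic_score": ..., "num_messages": ...}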