- """
- Step 5: Random stratified sampling for attorney labeling.
- """
- import random
- from typing import List
- import numpy as np
- from pipeline.models.base import PipelineStep
- from pipeline.common_defs import Chunk


class RandomSampler(PipelineStep):
    """Random stratified sampling for attorney labeling."""

    def __init__(
        self, n_samples: int = 50, seed: int = 42, output_dir: str = "./pipeline_output"
    ):
        super().__init__(output_dir)
        self.n_samples = n_samples
        self.seed = seed

    def execute(self, chunks: List[Chunk]) -> List[Chunk]:
        """
        Select random stratified samples.

        Args:
            chunks: List of semantically filtered chunks

        Returns:
            List of sampled chunks
        """
        self.logger.info(f"Selecting {self.n_samples} random samples...")
        self.logger.info(f"Random seed: {self.seed}")
        random.seed(self.seed)

        # Explicit None check: a legitimate score of 0.0 is falsy and would
        # otherwise be silently dropped.
        scores = [
            c.semantic_score_combined
            for c in chunks
            if c.semantic_score_combined is not None
        ]
        if not scores:
            self.logger.warning("No semantic scores found, using simple random sampling")
            samples = random.sample(chunks, min(self.n_samples, len(chunks)))
        else:
            # Stratify by score bands. The cutoffs are the 60th, 80th, and
            # 90th percentiles (not quartiles, as the old name suggested),
            # which deliberately oversamples the high-scoring tail.
            cutoffs = np.percentile(scores, [60, 80, 90])
            samples = self._stratified_sample(chunks, cutoffs)

        self.logger.info(f"Selected {len(samples)} samples")

        # Save samples for the labeling step
        self._save_samples(samples)
        return samples
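
    # With the defaults, a run emits log lines derived from the f-strings
    # above, along the lines of:
    #   Selecting 50 random samples...
    #   Random seed: 42
    #   Selected 50 samples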

    def _stratified_sample(
        self, chunks: List[Chunk], cutoffs: np.ndarray
    ) -> List[Chunk]:
        """Perform stratified sampling by score percentile band."""
        samples = []

        # Sample evenly from each band. The top band is open-ended so that a
        # chunk scoring exactly 1.0 (or above, if scores are not normalized)
        # is not excluded by the strict upper bound.
        bands = [(0, cutoffs[0]), (cutoffs[0], cutoffs[1]),
                 (cutoffs[1], cutoffs[2]), (cutoffs[2], float("inf"))]
        for q_low, q_high in bands:
            stratum = [
                c for c in chunks
                if c.semantic_score_combined is not None
                and q_low <= c.semantic_score_combined < q_high
            ]
            if stratum:
                n_select = min(self.n_samples // 4, len(stratum))
                samples.extend(random.sample(stratum, n_select))

        # Top up from the unsampled remainder if the bands came up short.
        # Compare by chunk_id rather than `c not in samples`, which is
        # quadratic and relies on Chunk equality.
        if len(samples) < self.n_samples:
            sampled_ids = {c.chunk_id for c in samples}
            remaining = [c for c in chunks if c.chunk_id not in sampled_ids]
            if remaining:
                n_more = min(self.n_samples - len(samples), len(remaining))
                samples.extend(random.sample(remaining, n_more))

        # Shuffle so strata are interleaved, then cap at n_samples
        random.shuffle(samples)
        return samples[:self.n_samples]
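
    # A worked sketch with hypothetical scores: for
    # scores = [0.2, 0.5, 0.6, 0.7, 0.8, 0.9], np.percentile(scores,
    # [60, 80, 90]) returns cutoffs of 0.7, 0.8, and 0.85, so the bands are
    # [0, 0.7), [0.7, 0.8), [0.8, 0.85), and [0.85, inf). Three of the four
    # per-band budgets therefore target the top 40% of scores.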

    def _save_samples(self, samples: List[Chunk]):
        """Save sampled chunks."""
        samples_data = [
            {
                "chunk_id": c.chunk_id,
                "start_line": c.start_line,
                "end_line": c.end_line,
                "semantic_score": c.semantic_score_combined,
                "num_messages": c.end_line - c.start_line,
            }
            for c in samples
        ]
        self.save_results(samples_data, "random_samples.json")
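
    # Illustrative record written to random_samples.json (values are
    # hypothetical):
    #   {"chunk_id": 17, "start_line": 340, "end_line": 365,
    #    "semantic_score": 0.74, "num_messages": 25}
    # Each record lets a labeled sample be traced back to its source lines.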


if __name__ == "__main__":
    # Example usage
    import json

    with open("pipeline_output/semantic_filtered_chunks.json", "r") as f:
        data = json.load(f)

    # Reconstruct chunks, simplified for the example: message bodies and
    # timestamps are not needed for sampling. Chunk is already imported at
    # module level.
    chunks = []
    for item in data["filtered_chunks"]:
        chunk = Chunk(
            chunk_id=item["chunk_id"],
            start_line=item["start_line"],
            end_line=item["end_line"],
            messages=[],
            combined_text="",
            timestamp_start="",
            timestamp_end="",
            semantic_score_combined=item["score_combined"],
        )
        chunks.append(chunk)

    sampler = RandomSampler(n_samples=100)
    samples = sampler.execute(chunks)
    print(f"Selected {len(samples)} samples")