random_sample_selector.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. #!/usr/bin/env python3
  2. """
  3. Random Sample Selector for Attorney Labeling
  4. Selects representative messages from filtered candidates for few-shot learning
  5. """
  6. import pandas as pd
  7. import random
  8. import json
  9. from pathlib import Path
  10. from datetime import datetime
  11. class RandomSampleSelector:
  12. """
  13. Selects random representative samples for attorney labeling.
  14. Ensures diversity across senders, time periods, and keyword matches.
  15. """
  16. def __init__(self, output_dir='./labeling_samples'):
  17. self.output_dir = Path(output_dir)
  18. self.output_dir.mkdir(exist_ok=True)
  19. def select_stratified_sample(self, messages_df, n_samples=20,
  20. stratify_by='sender', seed=42):
  21. """
  22. Select stratified random sample ensuring diversity.
  23. Args:
  24. messages_df: DataFrame with filtered candidate messages
  25. n_samples: Number of samples to select
  26. stratify_by: Column to stratify by ('sender', 'date', etc.)
  27. seed: Random seed for reproducibility
  28. """
  29. random.seed(seed)
  30. print(f"\nSelecting {n_samples} samples stratified by {stratify_by}...")
  31. # Get unique values for stratification
  32. if stratify_by in messages_df.columns:
  33. strata = messages_df[stratify_by].unique()
  34. samples_per_stratum = max(1, n_samples // len(strata))
  35. selected = []
  36. for stratum in strata:
  37. stratum_data = messages_df[messages_df[stratify_by] == stratum]
  38. n_select = min(samples_per_stratum, len(stratum_data))
  39. selected.extend(stratum_data.sample(n=n_select, random_state=seed).to_dict('records'))
  40. # If we need more samples, randomly select from remaining
  41. if len(selected) < n_samples:
  42. remaining = messages_df[~messages_df.index.isin([s['line_number'] for s in selected])]
  43. additional = remaining.sample(n=n_samples - len(selected), random_state=seed)
  44. selected.extend(additional.to_dict('records'))
  45. # Shuffle final selection
  46. random.shuffle(selected)
  47. selected = selected[:n_samples]
  48. else:
  49. # Simple random sample if stratify column doesn't exist
  50. selected = messages_df.sample(n=min(n_samples, len(messages_df)),
  51. random_state=seed).to_dict('records')
  52. print(f"Selected {len(selected)} samples")
  53. return selected
  54. def create_labeling_template(self, samples, context_window=3):
  55. """
  56. Create attorney labeling template with context.
  57. Shows each message with surrounding context for better evaluation.
  58. """
  59. print(f"\nCreating labeling template with context window of {context_window}...")
  60. labeling_data = []
  61. for i, sample in enumerate(samples, 1):
  62. # Create context (would need full dataset to get actual context)
  63. # For now, just format the sample message
  64. entry = {
  65. 'sample_id': i,
  66. 'line_number': sample.get('line_number', i),
  67. 'timestamp': sample.get('timestamp', ''),
  68. 'sender': sample.get('sender', ''),
  69. 'message': sample.get('message', ''),
  70. 'context_before': sample.get('context_before', []),
  71. 'context_after': sample.get('context_after', []),
  72. 'responsive': '', # Attorney fills this
  73. 'reasoning': '', # Attorney fills this
  74. 'criteria_matched': [] # Attorney fills this
  75. }
  76. labeling_data.append(entry)
  77. return labeling_data
  78. def save_labeling_template(self, labeling_data, filename='attorney_labeling_template.json'):
  79. """Save labeling template for attorney"""
  80. filepath = self.output_dir / filename
  81. with open(filepath, 'w') as f:
  82. json.dump(labeling_data, f, indent=2)
  83. print(f"\nLabeling template saved: {filepath}")
  84. # Also create a readable text version
  85. text_filepath = self.output_dir / filename.replace('.json', '.txt')
  86. with open(text_filepath, 'w') as f:
  87. f.write("ATTORNEY LABELING INSTRUCTIONS\n")
  88. f.write("=" * 80 + "\n\n")
  89. f.write("For each message below, please provide:\n")
  90. f.write("1. RESPONSIVE: YES or NO\n")
  91. f.write("2. REASONING: Brief explanation\n")
  92. f.write("3. CRITERIA: Which subpoena criteria matched (1-7)\n\n")
  93. f.write("=" * 80 + "\n\n")
  94. for entry in labeling_data:
  95. f.write(f"SAMPLE {entry['sample_id']}\n")
  96. f.write("-" * 80 + "\n")
  97. f.write(f"Line: {entry['line_number']}\n")
  98. f.write(f"Time: {entry['timestamp']}\n")
  99. f.write(f"Sender: {entry['sender']}\n")
  100. f.write(f"Message: {entry['message']}\n\n")
  101. f.write("RESPONSIVE: _______\n")
  102. f.write("REASONING: _______________________________________\n")
  103. f.write("CRITERIA: _______\n")
  104. f.write("\n" + "=" * 80 + "\n\n")
  105. print(f"Text template saved: {text_filepath}")
  106. return filepath
  107. def load_labeled_samples(self, filepath):
  108. """Load attorney-labeled samples"""
  109. with open(filepath, 'r') as f:
  110. return json.load(f)
  111. def create_few_shot_examples(self, labeled_samples):
  112. """
  113. Convert attorney-labeled samples into few-shot examples for prompts.
  114. """
  115. few_shot_examples = []
  116. for sample in labeled_samples:
  117. if sample.get('responsive'): # Only include if attorney labeled it
  118. example = {
  119. 'message': sample['message'],
  120. 'responsive': sample['responsive'],
  121. 'reasoning': sample['reasoning'],
  122. 'criteria': sample.get('criteria_matched', [])
  123. }
  124. few_shot_examples.append(example)
  125. return few_shot_examples
  126. def format_few_shot_prompt(self, few_shot_examples):
  127. """Format few-shot examples for inclusion in prompts"""
  128. prompt_text = "Here are examples of how to classify messages:\n\n"
  129. for i, example in enumerate(few_shot_examples, 1):
  130. status = "RESPONSIVE" if example['responsive'].upper() == 'YES' else "NOT RESPONSIVE"
  131. prompt_text += f"Example {i} ({status}):\n"
  132. prompt_text += f'Message: "{example["message"]}"\n'
  133. prompt_text += f"Reasoning: {example['reasoning']}\n"
  134. if example.get('criteria'):
  135. prompt_text += f"Criteria matched: {', '.join(map(str, example['criteria']))}\n"
  136. prompt_text += "\n"
  137. return prompt_text
  138. # Example usage
  139. if __name__ == "__main__":
  140. selector = RandomSampleSelector()
  141. # Load filtered candidates (from previous pipeline step)
  142. # candidates_df = pd.read_csv('discovery_output/filtered/candidate_messages.csv')
  143. # Select 20 random samples
  144. # samples = selector.select_stratified_sample(candidates_df, n_samples=20)
  145. # Create labeling template
  146. # labeling_data = selector.create_labeling_template(samples)
  147. # Save for attorney
  148. # selector.save_labeling_template(labeling_data)
  149. print("\nTo use this script:")
  150. print("1. Load your filtered candidate messages")
  151. print("2. Run select_stratified_sample() to get random samples")
  152. print("3. Run create_labeling_template() to format for attorney")
  153. print("4. Attorney labels the samples")
  154. print("5. Run create_few_shot_examples() to convert to prompt format")