ethical_discovery_pipeline.py

#!/usr/bin/env python3
"""
Ethical Open-Source Legal Discovery Pipeline
Uses Mistral models (French company, no Trump connections)
Integrates: dual-model semantic filtering + random sampling + Mistral inference
"""
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import json
import random
from pathlib import Path
from typing import List, Dict
import re

class EthicalDiscoveryPipeline:
    """
    Complete ethical discovery pipeline using only open-source models
    from companies with no Trump connections.
    """

    def __init__(self, csv_path: str, output_dir: str = './ethical_discovery_output'):
        self.csv_path = csv_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)

        # Dual embedding models (lazy-loaded in dual_semantic_filter)
        self.embedding_model1 = None
        self.embedding_model2 = None

        # Jennifer Capasso v. MSK criteria
        self.criteria = {
            'plaintiff_name': 'Jennifer Capasso',
            'plaintiff_variations': [
                'jennifer capasso', 'jen capasso', 'jennifer', 'jen',
                'capasso', 'j capasso', 'jc', 'jenny'
            ],
            'facility_names': [
                'memorial sloan kettering', 'msk', 'sloan kettering',
                'memorial sloan', 'sloan', 'kettering'
            ],
            'key_topics': [
                # Treatment at MSK
                'treatment', 'medical care', 'doctor', 'physician', 'nurse',
                'appointment', 'visit', 'hospital', 'clinic', 'surgery',
                'procedure', 'diagnosis', 'medication', 'prescription',
                # Complaints
                'complaint', 'complain', 'complained', 'issue', 'problem',
                'concern', 'patient representative', 'patient advocate',
                # Patient information updates
                'patient information', 'medical records', 'pronouns',
                'gender identity', 'gender marker', 'update records',
                # Discrimination
                'discrimination', 'discriminate', 'discriminated',
                'bias', 'unfair', 'mistreat', 'transphobia', 'misgendered',
                'deadname', 'wrong pronouns', 'refused', 'denied',
                # March 7, 2022 surgery
                'march 7', 'march 2022', '3/7/22', '3/7/2022', 'surgery',
                # Emotional distress
                'emotional distress', 'mental anguish', 'pain', 'suffering',
                'trauma', 'anxious', 'depressed', 'stress'
            ]
        }

    def load_and_preprocess(self) -> pd.DataFrame:
        """Load Signal chat CSV and preprocess messages."""
        print(f"\nLoading Signal chat CSV: {self.csv_path}")
        df = pd.read_csv(self.csv_path)
        df.columns = df.columns.str.lower().str.strip()
        # Add stable line numbers for citing individual messages
        df['line_number'] = range(1, len(df) + 1)
        df['message'] = df['message'].fillna('')
        df['message_normalized'] = df['message'].apply(self.normalize_text)
        print(f"Loaded {len(df):,} messages")
        return df
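
    # Assumed input schema (an assumption: Signal-export column names vary by
    # export tool). After lowercasing, the pipeline reads these columns:
    #
    #   timestamp,sender,message
    #   2022-03-01 09:14:00,Jen,heading to msk for my appt
    #
    # 'message' is required here; create_chunks, create_labeling_template, and
    # save_for_mistral_inference additionally read 'timestamp' and 'sender'.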

    def normalize_text(self, text: str) -> str:
        """Normalize text with abbreviation expansion."""
        if pd.isna(text) or text == '':
            return ""
        text = str(text).lower()
        expansions = {
            'msk': 'memorial sloan kettering',
            'dr': 'doctor',
            'appt': 'appointment', 'hosp': 'hospital',
            'med': 'medical', 'rx': 'prescription',
            'pt': 'patient', 'pron': 'pronoun'
        }
        # Expand on word boundaries only, so 'pt' -> 'patient' cannot corrupt
        # 'accept', and 'med' cannot turn 'medical' into 'medicalical'
        for abbr, full in expansions.items():
            text = re.sub(rf'\b{abbr}\b\.?', full, text)
        return text

    def create_chunks(self, df: pd.DataFrame, chunk_size: int = 20,
                      overlap: int = 5) -> List[Dict]:
        """Create overlapping chunks of consecutive messages."""
        print(f"\nCreating chunks (size={chunk_size}, overlap={overlap})...")
        chunks = []
        total = len(df)
        step = chunk_size - overlap
        for i in range(0, total, step):
            chunk_df = df.iloc[i:i + chunk_size]
            if len(chunk_df) == 0:
                break
            chunk = {
                'chunk_id': len(chunks),
                'start_line': int(chunk_df['line_number'].iloc[0]),
                'end_line': int(chunk_df['line_number'].iloc[-1]),
                'messages': chunk_df.to_dict('records'),
                'combined_text': ' '.join(chunk_df['message_normalized'].fillna('')),
                'timestamp_start': chunk_df['timestamp'].iloc[0],
                'timestamp_end': chunk_df['timestamp'].iloc[-1]
            }
            chunks.append(chunk)
        print(f"Created {len(chunks):,} chunks")
        return chunks

    def keyword_filter(self, chunks: List[Dict]) -> List[Dict]:
        """Filter chunks by subpoena-derived keywords."""
        print("\nApplying keyword filter...")
        all_keywords = (
            self.criteria['plaintiff_variations'] +
            self.criteria['facility_names'] +
            self.criteria['key_topics']
        )
        filtered = []
        for chunk in chunks:
            text = chunk['combined_text']
            matches = [kw for kw in all_keywords if kw in text]
            if matches:
                chunk['keyword_matches'] = matches
                chunk['keyword_score'] = len(set(matches))
                filtered.append(chunk)
        reduction = (1 - len(filtered) / len(chunks)) * 100
        print(f"Filtered: {len(filtered):,} / {len(chunks):,} chunks ({reduction:.1f}% reduction)")
        return filtered

    def dual_semantic_filter(self, chunks: List[Dict],
                             threshold1: float = 0.25,
                             threshold2: float = 0.25,
                             merge_strategy: str = 'union') -> List[Dict]:
        """
        Semantic filtering with two embedding models.
        Uses union strategy for high recall.
        """
        print("\nApplying dual-model semantic filter...")
        print(f"  Strategy: {merge_strategy}")

        # Load models lazily on first use
        if self.embedding_model1 is None:
            print("  Loading Model 1: all-MiniLM-L6-v2...")
            self.embedding_model1 = SentenceTransformer('all-MiniLM-L6-v2')
        if self.embedding_model2 is None:
            print("  Loading Model 2: all-mpnet-base-v2...")
            self.embedding_model2 = SentenceTransformer('all-mpnet-base-v2')

        # Query texts from subpoena criteria
        queries = [
            "Jennifer Capasso treatment at Memorial Sloan Kettering Cancer Center MSK",
            "complaint to MSK staff about Jennifer Capasso patient care",
            "update patient pronouns gender identity markers at MSK hospital",
            "gender markers at other hospitals medical records",
            "discrimination based on gender identity transgender",
            "March 7 2022 surgery at MSK Memorial Sloan Kettering",
            "emotional distress mental anguish pain suffering from medical treatment"
        ]

        # Compute query embeddings
        print("  Computing query embeddings...")
        query_emb1 = self.embedding_model1.encode(queries)
        query_emb2 = self.embedding_model2.encode(queries)

        # Compute chunk embeddings
        print(f"  Computing embeddings for {len(chunks):,} chunks...")
        chunk_texts = [c['combined_text'] for c in chunks]
        chunk_emb1 = self.embedding_model1.encode(chunk_texts, show_progress_bar=True, batch_size=32)
        chunk_emb2 = self.embedding_model2.encode(chunk_texts, show_progress_bar=True, batch_size=32)

        # Score each chunk by its best-matching query
        print("  Computing semantic similarities...")
        similarities1 = cosine_similarity(chunk_emb1, query_emb1)
        similarities2 = cosine_similarity(chunk_emb2, query_emb2)
        max_sim1 = similarities1.max(axis=1)
        max_sim2 = similarities2.max(axis=1)

        # Apply merge strategy
        filtered = []
        for i, chunk in enumerate(chunks):
            score1 = float(max_sim1[i])
            score2 = float(max_sim2[i])
            if merge_strategy == 'union':
                passes = (score1 >= threshold1) or (score2 >= threshold2)
                combined_score = max(score1, score2)
            elif merge_strategy == 'intersection':
                passes = (score1 >= threshold1) and (score2 >= threshold2)
                combined_score = min(score1, score2)
            else:  # weighted
                combined_score = 0.4 * score1 + 0.6 * score2
                passes = combined_score >= (0.4 * threshold1 + 0.6 * threshold2)
            if passes:
                chunk['semantic_score_model1'] = score1
                chunk['semantic_score_model2'] = score2
                chunk['semantic_score_combined'] = combined_score
                filtered.append(chunk)

        print(f"  Model 1 alone: {(max_sim1 >= threshold1).sum()}")
        print(f"  Model 2 alone: {(max_sim2 >= threshold2).sum()}")
        print(f"  Combined: {len(filtered):,} chunks")
        print(f"  Total reduction: {(1 - len(filtered) / len(chunks)) * 100:.1f}%")
        return filtered
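
    # Hedged sketch, not called by the pipeline: once attorney labels exist for
    # the random samples, a sweep like this can replace the default 0.25
    # thresholds with empirically chosen ones. 'labeled' is an assumed format:
    # a list of (semantic_score_combined, is_responsive) pairs.
    def sweep_thresholds(self, labeled: List[tuple],
                         candidates=(0.15, 0.20, 0.25, 0.30, 0.35)) -> Dict[float, float]:
        """Return {threshold: recall} over attorney-labeled (score, label) pairs."""
        positives = [score for score, is_responsive in labeled if is_responsive]
        if not positives:
            return {t: 0.0 for t in candidates}
        # Recall at threshold t = fraction of responsive chunks scoring >= t
        return {t: sum(1 for s in positives if s >= t) / len(positives)
                for t in candidates}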

    def select_random_samples(self, chunks: List[Dict], n_samples: int = 20,
                              seed: int = 42) -> List[Dict]:
        """
        Randomly select samples for attorney labeling.
        Stratifies by semantic score to ensure diversity.
        """
        print(f"\nSelecting {n_samples} random samples for attorney labeling...")
        random.seed(seed)

        # Stratify by score quartiles; the top stratum is open-ended so the
        # highest-scoring chunks are never excluded by the < upper-bound check
        scores = [c.get('semantic_score_combined', 0) for c in chunks]
        quartiles = np.percentile(scores, [25, 50, 75])
        samples = []
        for q_low, q_high in [(0, quartiles[0]), (quartiles[0], quartiles[1]),
                              (quartiles[1], quartiles[2]), (quartiles[2], float('inf'))]:
            stratum = [c for c in chunks if q_low <= c.get('semantic_score_combined', 0) < q_high]
            if stratum:
                n_select = min(n_samples // 4, len(stratum))
                samples.extend(random.sample(stratum, n_select))

        # Fill remaining slots if the strata came up short
        if len(samples) < n_samples:
            remaining = [c for c in chunks if c not in samples]
            samples.extend(random.sample(remaining, min(n_samples - len(samples), len(remaining))))
        random.shuffle(samples)
        samples = samples[:n_samples]
        print(f"Selected {len(samples)} samples across score ranges")
        return samples

    def create_labeling_template(self, samples: List[Dict],
                                 output_file: str = 'attorney_labeling_template.txt'):
        """Create attorney-friendly labeling template"""
        filepath = self.output_dir / output_file
        with open(filepath, 'w') as f:
            f.write("ATTORNEY LABELING TEMPLATE\n")
            f.write("Jennifer Capasso v. Memorial Sloan Kettering Cancer Center\n")
            f.write("=" * 80 + "\n\n")
            f.write("INSTRUCTIONS:\n")
            f.write("For each message below, please provide:\n")
            f.write("1. RESPONSIVE: YES or NO\n")
            f.write("2. REASONING: Brief explanation of your decision\n")
            f.write("3. CRITERIA: Which subpoena criteria matched (1-7):\n")
            f.write("   1. Treatment at MSK\n")
            f.write("   2. Complaints to MSK staff\n")
            f.write("   3. Pronoun/gender marker update requests\n")
            f.write("   4. Gender markers at other hospitals\n")
            f.write("   5. Prior discrimination (any setting)\n")
            f.write("   6. March 7, 2022 surgery\n")
            f.write("   7. Emotional distress/economic loss\n\n")
            f.write("=" * 80 + "\n\n")
            for i, sample in enumerate(samples, 1):
                # Lead with the chunk's first message as the one to label
                first_msg = sample['messages'][0] if sample['messages'] else {}
                f.write(f"SAMPLE {i}\n")
                f.write("-" * 80 + "\n")
                f.write(f"Line: {first_msg.get('line_number', 'N/A')}\n")
                f.write(f"Time: {first_msg.get('timestamp', 'N/A')}\n")
                f.write(f"Sender: {first_msg.get('sender', 'N/A')}\n")
                f.write(f"Message: {first_msg.get('message', 'N/A')}\n\n")
                # Show context: the first five messages of the chunk
                f.write("Context (surrounding messages):\n")
                for j, msg in enumerate(sample['messages'][:5], 1):
                    marker = ">>>" if j == 1 else "   "
                    f.write(f"{marker} [{msg.get('sender', '?')}]: {msg.get('message', '')[:80]}...\n")
                f.write("\n")
                f.write("RESPONSIVE: _______\n")
                f.write("REASONING: _____________________________________________\n")
                f.write("CRITERIA: _______\n")
                f.write("\n" + "=" * 80 + "\n\n")
        print(f"\nLabeling template saved: {filepath}")
        print("Please have the attorney complete this template and save it as:")
        print(f"  {self.output_dir / 'attorney_labels_completed.txt'}")
        return filepath
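
    # Hedged sketch, not wired into the pipeline: a minimal parser for the
    # completed template above, assuming the attorney fills in RESPONSIVE and
    # REASONING in place and leaves the SAMPLE/Message field layout untouched.
    # A production parser would need to tolerate free-form edits.
    def parse_attorney_labels(self, completed_file: str) -> List[Dict]:
        """Parse attorney answers back into few-shot example records."""
        def flush(record, out):
            # Keep only samples the attorney actually answered
            if record.get('responsive') in ('YES', 'NO'):
                out.append(record)

        labels, current = [], {}
        for line in Path(completed_file).read_text().splitlines():
            line = line.strip()
            if line.startswith('SAMPLE '):
                flush(current, labels)
                current = {}
            elif line.startswith('Message:'):
                current['message'] = line.split(':', 1)[1].strip()
            elif line.startswith('RESPONSIVE:'):
                current['responsive'] = line.split(':', 1)[1].strip(' _').upper()
            elif line.startswith('REASONING:'):
                current['reasoning'] = line.split(':', 1)[1].strip(' _')
        flush(current, labels)  # last sample has no trailing SAMPLE header
        return labels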

    def save_for_mistral_inference(self, chunks: List[Dict],
                                   few_shot_file: str = None):
        """
        Save chunks in a format ready for Mistral model inference.
        Optionally includes few-shot examples from attorney labels.
        """
        print("\nPreparing data for Mistral inference...")

        # Load few-shot examples if provided
        few_shot_prompt = ""
        if few_shot_file and Path(few_shot_file).exists():
            print(f"  Loading few-shot examples from: {few_shot_file}")
            # Placeholder only: a real implementation would parse the completed
            # attorney template (see the parse_attorney_labels sketch above)
            few_shot_prompt = "\n\nHere are examples of how to classify messages:\n"
            few_shot_prompt += "[Attorney-labeled examples would be inserted here]\n"

        # Create inference requests
        inference_requests = []
        system_prompt = """You are a legal document review specialist analyzing Signal chat messages for a discrimination lawsuit.
CASE: Jennifer Capasso v. Memorial Sloan Kettering Cancer Center (MSK)
CLAIM: Discrimination based on gender identity
SUBPOENA CRITERIA - Messages are responsive if they relate to:
1. Jennifer Capasso's treatment at MSK
2. Complaints to MSK staff about Jennifer Capasso
3. Requests to update Jennifer Capasso's pronouns/gender markers at MSK
4. Gender markers for Jennifer Capasso at other hospitals
5. Prior discrimination Jennifer Capasso experienced (any setting)
6. Jennifer Capasso's March 7, 2022 surgery at MSK
7. Emotional distress/economic loss from MSK treatment
IMPORTANT: Err on the side of OVER-INCLUSION (high recall)."""

        for chunk in chunks:
            messages_text = ""
            for msg in chunk['messages']:
                messages_text += f"Line {msg['line_number']} [{msg['sender']}]: {msg['message']}\n"
            prompt = f"""{system_prompt}
{few_shot_prompt}
MESSAGES TO REVIEW (Lines {chunk['start_line']}-{chunk['end_line']}):
{messages_text}
Respond with JSON:
{{
"responsive_line_numbers": [list of responsive line numbers],
"reasoning": "brief explanation",
"confidence": "high/medium/low"
}}"""
            inference_requests.append({
                'chunk_id': chunk['chunk_id'],
                'prompt': prompt,
                'chunk_data': chunk
            })

        # Save requests; default=str guards against stray non-JSON-serializable
        # values (e.g. parsed timestamps) inside chunk_data
        requests_file = self.output_dir / 'mistral_inference_requests.jsonl'
        with open(requests_file, 'w') as f:
            for req in inference_requests:
                f.write(json.dumps(req, default=str) + '\n')

        print(f"Saved {len(inference_requests):,} inference requests to: {requests_file}")
        print("\nNext steps:")
        print("1. Deploy Mixtral 8x22B on Vast.ai (H100 @ $1.33-1.56/hr)")
        print("2. Deploy Mistral 7B on Vast.ai (RTX 4090 @ $0.34/hr)")
        print("3. Run inference on both models")
        print("4. Merge results (take union for high recall)")
        return requests_file
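

# ---------------------------------------------------------------------------
# Hedged sketches for the "next steps" printed above; neither is wired into
# the pipeline. run_inference assumes a Mistral-family model is already being
# served behind an OpenAI-compatible HTTP completions endpoint (for example,
# vLLM's /v1/completions route); the endpoint URL and model name are
# placeholders you must supply. merge_results_union assumes each completion is
# the JSON object the prompt requests, with a 'responsive_line_numbers' list.
# ---------------------------------------------------------------------------
import urllib.request  # stdlib-only HTTP, to keep the sketch self-contained


def run_inference(requests_file: str, endpoint: str, model: str,
                  output_file: str = 'mistral_results.jsonl') -> str:
    """Send each saved prompt to an OpenAI-compatible completions endpoint."""
    with open(requests_file) as fin, open(output_file, 'w') as fout:
        for line in fin:
            req = json.loads(line)
            payload = json.dumps({
                'model': model,
                'prompt': req['prompt'],
                'max_tokens': 512,
                'temperature': 0.0,  # deterministic review decisions
            }).encode()
            http_req = urllib.request.Request(
                endpoint, data=payload,
                headers={'Content-Type': 'application/json'})
            with urllib.request.urlopen(http_req) as resp:
                body = json.load(resp)
            fout.write(json.dumps({
                'chunk_id': req['chunk_id'],
                'completion': body['choices'][0]['text'],
            }) + '\n')
    return output_file


def merge_results_union(file_a: str, file_b: str) -> set:
    """Union the responsive line numbers from two models' result files."""
    responsive = set()
    for path in (file_a, file_b):
        with open(path) as f:
            for line in f:
                rec = json.loads(line)
                try:
                    parsed = json.loads(rec['completion'])
                    responsive.update(parsed.get('responsive_line_numbers', []))
                except (json.JSONDecodeError, TypeError):
                    continue  # skip malformed model output; union keeps recall high
    return responsive
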

# Example usage
if __name__ == "__main__":
    # Initialize pipeline
    pipeline = EthicalDiscoveryPipeline('signal_messages.csv')

    # Run complete pipeline
    print("\nETHICAL DISCOVERY PIPELINE")
    print("Using only Mistral models (French company, no Trump connections)")
    print("=" * 80)

    # Step 1: Load and preprocess
    df = pipeline.load_and_preprocess()

    # Step 2: Create chunks
    chunks = pipeline.create_chunks(df, chunk_size=20, overlap=5)

    # Step 3: Keyword filter
    keyword_filtered = pipeline.keyword_filter(chunks)

    # Step 4: Dual-model semantic filter
    semantic_filtered = pipeline.dual_semantic_filter(
        keyword_filtered,
        threshold1=0.25,
        threshold2=0.25,
        merge_strategy='union'  # High recall
    )

    # Step 5: Select random samples for attorney
    samples = pipeline.select_random_samples(semantic_filtered, n_samples=20)

    # Step 6: Create labeling template
    template_file = pipeline.create_labeling_template(samples)

    # Step 7: Prepare for Mistral inference
    requests_file = pipeline.save_for_mistral_inference(semantic_filtered)

    print("\n" + "=" * 80)
    print("PIPELINE COMPLETE")
    print("=" * 80)
    print(f"\nReduced from {len(df):,} messages to {len(semantic_filtered):,} chunks")
    # Approximate: chunks overlap by 5 messages, so 20 msgs/chunk overcounts coverage
    print(f"Total reduction (approx.): {(1 - len(semantic_filtered) * 20 / len(df)) * 100:.1f}%")
    # Cost model: chunks x minutes-per-chunk / 60 x hourly GPU rate
    print("\nEstimated cost for Mistral inference:")
    print(f"  Mixtral 8x22B: ${len(semantic_filtered) * 0.5 / 60 * 1.45:.2f} (4-8 hours)")
    print(f"  Mistral 7B: ${len(semantic_filtered) * 0.3 / 60 * 0.49:.2f} (2-4 hours)")
    print(f"  Total: ${(len(semantic_filtered) * 0.5 / 60 * 1.45) + (len(semantic_filtered) * 0.3 / 60 * 0.49):.2f}")