#!/usr/bin/env python3 """ Ethical Open-Source Legal Discovery Pipeline Uses Mistral models (French company, no Trump connections) Integrates: dual-model semantic filtering + random sampling + Mistral inference """ import pandas as pd import numpy as np from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity import json import random from pathlib import Path from typing import List, Dict import re class EthicalDiscoveryPipeline: """ Complete ethical discovery pipeline using only open-source models from companies with no Trump connections. """ def __init__(self, csv_path: str, output_dir: str = './ethical_discovery_output'): self.csv_path = csv_path self.output_dir = Path(output_dir) self.output_dir.mkdir(exist_ok=True) # Dual embedding models self.embedding_model1 = None self.embedding_model2 = None # Jennifer Capasso v. MSK criteria self.criteria = { 'plaintiff_name': 'Jennifer Capasso', 'plaintiff_variations': [ 'jennifer capasso', 'jen capasso', 'jennifer', 'jen', 'capasso', 'j capasso', 'jc', 'jenny' ], 'facility_names': [ 'memorial sloan kettering', 'msk', 'sloan kettering', 'memorial sloan', 'sloan', 'kettering' ], 'key_topics': [ # Treatment at MSK 'treatment', 'medical care', 'doctor', 'physician', 'nurse', 'appointment', 'visit', 'hospital', 'clinic', 'surgery', 'procedure', 'diagnosis', 'medication', 'prescription', # Complaints 'complaint', 'complain', 'complained', 'issue', 'problem', 'concern', 'patient representative', 'patient advocate', # Patient information updates 'patient information', 'medical records', 'pronouns', 'gender identity', 'gender marker', 'update records', # Discrimination 'discrimination', 'discriminate', 'discriminated', 'bias', 'unfair', 'mistreat', 'transphobia', 'misgendered', 'deadname', 'wrong pronouns', 'refused', 'denied', # March 7, 2022 surgery 'march 7', 'march 2022', '3/7/22', '3/7/2022', 'surgery', # Emotional distress 'emotional distress', 'mental anguish', 'pain', 'suffering', 'trauma', 'anxious', 'depressed', 'stress' ] } def load_and_preprocess(self) -> pd.DataFrame: """Load Signal CSV and preprocess""" print(f"\nLoading Signal chat CSV: {self.csv_path}") df = pd.read_csv(self.csv_path) df.columns = df.columns.str.lower().str.strip() # Add line numbers df['line_number'] = range(1, len(df) + 1) df['message'] = df['message'].fillna('') df['message_normalized'] = df['message'].apply(self.normalize_text) print(f"Loaded {len(df):,} messages") return df def normalize_text(self, text: str) -> str: """Normalize text with abbreviation expansion""" if pd.isna(text) or text == '': return "" text = str(text).lower() expansions = { 'msk': 'memorial sloan kettering', 'dr.': 'doctor', 'dr ': 'doctor ', 'appt': 'appointment', 'hosp': 'hospital', 'med': 'medical', 'rx': 'prescription', 'pt': 'patient', 'pron': 'pronoun' } for abbr, full in expansions.items(): text = text.replace(abbr, full) return text def create_chunks(self, df: pd.DataFrame, chunk_size: int = 20, overlap: int = 5) -> List[Dict]: """Create overlapping chunks""" print(f"\nCreating chunks (size={chunk_size}, overlap={overlap})...") chunks = [] total = len(df) step = chunk_size - overlap for i in range(0, total, step): chunk_df = df.iloc[i:i+chunk_size] if len(chunk_df) == 0: break chunk = { 'chunk_id': len(chunks), 'start_line': int(chunk_df['line_number'].iloc[0]), 'end_line': int(chunk_df['line_number'].iloc[-1]), 'messages': chunk_df.to_dict('records'), 'combined_text': ' '.join(chunk_df['message_normalized'].fillna('')), 'timestamp_start': chunk_df['timestamp'].iloc[0], 'timestamp_end': chunk_df['timestamp'].iloc[-1] } chunks.append(chunk) print(f"Created {len(chunks):,} chunks") return chunks def keyword_filter(self, chunks: List[Dict]) -> List[Dict]: """Filter by keywords""" print("\nApplying keyword filter...") all_keywords = ( self.criteria['plaintiff_variations'] + self.criteria['facility_names'] + self.criteria['key_topics'] ) filtered = [] for chunk in chunks: text = chunk['combined_text'] matches = [kw for kw in all_keywords if kw in text] if matches: chunk['keyword_matches'] = matches chunk['keyword_score'] = len(set(matches)) filtered.append(chunk) reduction = (1 - len(filtered)/len(chunks)) * 100 print(f"Filtered: {len(filtered):,} / {len(chunks):,} chunks ({reduction:.1f}% reduction)") return filtered def dual_semantic_filter(self, chunks: List[Dict], threshold1: float = 0.25, threshold2: float = 0.25, merge_strategy: str = 'union') -> List[Dict]: """ Semantic filtering with two embedding models. Uses union strategy for high recall. """ print("\nApplying dual-model semantic filter...") print(f" Strategy: {merge_strategy}") # Load models if not already loaded if self.embedding_model1 is None: print(" Loading Model 1: all-MiniLM-L6-v2...") self.embedding_model1 = SentenceTransformer('all-MiniLM-L6-v2') if self.embedding_model2 is None: print(" Loading Model 2: all-mpnet-base-v2...") self.embedding_model2 = SentenceTransformer('all-mpnet-base-v2') # Query texts from subpoena criteria queries = [ "Jennifer Capasso treatment at Memorial Sloan Kettering Cancer Center MSK", "complaint to MSK staff about Jennifer Capasso patient care", "update patient pronouns gender identity markers at MSK hospital", "gender markers at other hospitals medical records", "discrimination based on gender identity transgender", "March 7 2022 surgery at MSK Memorial Sloan Kettering", "emotional distress mental anguish pain suffering from medical treatment" ] # Compute query embeddings print(" Computing query embeddings...") query_emb1 = self.embedding_model1.encode(queries) query_emb2 = self.embedding_model2.encode(queries) # Compute chunk embeddings print(f" Computing embeddings for {len(chunks):,} chunks...") chunk_texts = [c['combined_text'] for c in chunks] chunk_emb1 = self.embedding_model1.encode(chunk_texts, show_progress_bar=True, batch_size=32) chunk_emb2 = self.embedding_model2.encode(chunk_texts, show_progress_bar=True, batch_size=32) # Compute similarities print(" Computing semantic similarities...") similarities1 = cosine_similarity(chunk_emb1, query_emb1) similarities2 = cosine_similarity(chunk_emb2, query_emb2) max_sim1 = similarities1.max(axis=1) max_sim2 = similarities2.max(axis=1) # Apply merge strategy filtered = [] for i, chunk in enumerate(chunks): score1 = float(max_sim1[i]) score2 = float(max_sim2[i]) if merge_strategy == 'union': passes = (score1 >= threshold1) or (score2 >= threshold2) combined_score = max(score1, score2) elif merge_strategy == 'intersection': passes = (score1 >= threshold1) and (score2 >= threshold2) combined_score = min(score1, score2) else: # weighted combined_score = 0.4 * score1 + 0.6 * score2 passes = combined_score >= ((0.4 * threshold1 + 0.6 * threshold2)) if passes: chunk['semantic_score_model1'] = score1 chunk['semantic_score_model2'] = score2 chunk['semantic_score_combined'] = combined_score filtered.append(chunk) print(f" Model 1 alone: {(max_sim1 >= threshold1).sum()}") print(f" Model 2 alone: {(max_sim2 >= threshold2).sum()}") print(f" Combined: {len(filtered):,} chunks") print(f" Total reduction: {(1 - len(filtered)/len(chunks))*100:.1f}%") return filtered def select_random_samples(self, chunks: List[Dict], n_samples: int = 20, seed: int = 42) -> List[Dict]: """ Randomly select samples for attorney labeling. Stratifies by semantic score to ensure diversity. """ print(f"\nSelecting {n_samples} random samples for attorney labeling...") random.seed(seed) # Stratify by score quartiles scores = [c.get('semantic_score_combined', 0) for c in chunks] quartiles = np.percentile(scores, [25, 50, 75]) samples = [] for q_low, q_high in [(0, quartiles[0]), (quartiles[0], quartiles[1]), (quartiles[1], quartiles[2]), (quartiles[2], 1.0)]: stratum = [c for c in chunks if q_low <= c.get('semantic_score_combined', 0) < q_high] if stratum: n_select = min(n_samples // 4, len(stratum)) samples.extend(random.sample(stratum, n_select)) # Fill remaining if needed if len(samples) < n_samples: remaining = [c for c in chunks if c not in samples] samples.extend(random.sample(remaining, min(n_samples - len(samples), len(remaining)))) random.shuffle(samples) samples = samples[:n_samples] print(f"Selected {len(samples)} samples across score ranges") return samples def create_labeling_template(self, samples: List[Dict], output_file: str = 'attorney_labeling_template.txt'): """Create attorney-friendly labeling template""" filepath = self.output_dir / output_file with open(filepath, 'w') as f: f.write("ATTORNEY LABELING TEMPLATE\n") f.write("Jennifer Capasso v. Memorial Sloan Kettering Cancer Center\n") f.write("=" * 80 + "\n\n") f.write("INSTRUCTIONS:\n") f.write("For each message below, please provide:\n") f.write("1. RESPONSIVE: YES or NO\n") f.write("2. REASONING: Brief explanation of your decision\n") f.write("3. CRITERIA: Which subpoena criteria matched (1-7):\n") f.write(" 1. Treatment at MSK\n") f.write(" 2. Complaints to MSK staff\n") f.write(" 3. Pronoun/gender marker update requests\n") f.write(" 4. Gender markers at other hospitals\n") f.write(" 5. Prior discrimination (any setting)\n") f.write(" 6. March 7, 2022 surgery\n") f.write(" 7. Emotional distress/economic loss\n\n") f.write("=" * 80 + "\n\n") for i, sample in enumerate(samples, 1): # Get first message from chunk for labeling first_msg = sample['messages'][0] if sample['messages'] else {} f.write(f"SAMPLE {i}\n") f.write("-" * 80 + "\n") f.write(f"Line: {first_msg.get('line_number', 'N/A')}\n") f.write(f"Time: {first_msg.get('timestamp', 'N/A')}\n") f.write(f"Sender: {first_msg.get('sender', 'N/A')}\n") f.write(f"Message: {first_msg.get('message', 'N/A')}\n\n") # Show context (2 messages before and after) f.write("Context (surrounding messages):\n") for j, msg in enumerate(sample['messages'][:5], 1): marker = ">>>" if j == 1 else " " f.write(f"{marker} [{msg.get('sender', '?')}]: {msg.get('message', '')[:80]}...\n") f.write("\n") f.write("RESPONSIVE: _______\n") f.write("REASONING: _____________________________________________\n") f.write("CRITERIA: _______\n") f.write("\n" + "=" * 80 + "\n\n") print(f"\nLabeling template saved: {filepath}") print(f"Please have attorney complete this template and save as:") print(f" {self.output_dir / 'attorney_labels_completed.txt'}") return filepath def save_for_mistral_inference(self, chunks: List[Dict], few_shot_file: str = None): """ Save chunks in format ready for Mistral model inference. Optionally includes few-shot examples from attorney labels. """ print("\nPreparing data for Mistral inference...") # Load few-shot examples if provided few_shot_prompt = "" if few_shot_file and Path(few_shot_file).exists(): print(f" Loading few-shot examples from: {few_shot_file}") # Parse attorney labels (simplified - would need actual parser) few_shot_prompt = "\n\nHere are examples of how to classify messages:\n" few_shot_prompt += "[Attorney-labeled examples would be inserted here]\n" # Create inference requests inference_requests = [] system_prompt = """You are a legal document review specialist analyzing Signal chat messages for a discrimination lawsuit. CASE: Jennifer Capasso v. Memorial Sloan Kettering Cancer Center (MSK) CLAIM: Discrimination based on gender identity SUBPOENA CRITERIA - Messages are responsive if they relate to: 1. Jennifer Capasso's treatment at MSK 2. Complaints to MSK staff about Jennifer Capasso 3. Requests to update Jennifer Capasso's pronouns/gender markers at MSK 4. Gender markers for Jennifer Capasso at other hospitals 5. Prior discrimination Jennifer Capasso experienced (any setting) 6. Jennifer Capasso's March 7, 2022 surgery at MSK 7. Emotional distress/economic loss from MSK treatment IMPORTANT: Err on side of OVER-INCLUSION (high recall).""" for chunk in chunks: messages_text = "" for msg in chunk['messages']: messages_text += f"Line {msg['line_number']} [{msg['sender']}]: {msg['message']}\n" prompt = f"""{system_prompt} {few_shot_prompt} MESSAGES TO REVIEW (Lines {chunk['start_line']}-{chunk['end_line']}): {messages_text} Respond with JSON: {{ "responsive_line_numbers": [list of responsive line numbers], "reasoning": "brief explanation", "confidence": "high/medium/low" }}""" inference_requests.append({ 'chunk_id': chunk['chunk_id'], 'prompt': prompt, 'chunk_data': chunk }) # Save requests requests_file = self.output_dir / 'mistral_inference_requests.jsonl' with open(requests_file, 'w') as f: for req in inference_requests: f.write(json.dumps(req) + '\n') print(f"Saved {len(inference_requests):,} inference requests to: {requests_file}") print("\nNext steps:") print("1. Deploy Mixtral 8x22B on Vast.ai (H100 @ $1.33-1.56/hr)") print("2. Deploy Mistral 7B on Vast.ai (RTX 4090 @ $0.34/hr)") print("3. Run inference on both models") print("4. Merge results (take union for high recall)") return requests_file # Example usage if __name__ == "__main__": # Initialize pipeline pipeline = EthicalDiscoveryPipeline('signal_messages.csv') # Run complete pipeline print("\nETHICAL DISCOVERY PIPELINE") print("Using only Mistral models (French company, no Trump connections)") print("=" * 80) # Step 1: Load and preprocess df = pipeline.load_and_preprocess() # Step 2: Create chunks chunks = pipeline.create_chunks(df, chunk_size=20, overlap=5) # Step 3: Keyword filter keyword_filtered = pipeline.keyword_filter(chunks) # Step 4: Dual-model semantic filter semantic_filtered = pipeline.dual_semantic_filter( keyword_filtered, threshold1=0.25, threshold2=0.25, merge_strategy='union' # High recall ) # Step 5: Select random samples for attorney samples = pipeline.select_random_samples(semantic_filtered, n_samples=20) # Step 6: Create labeling template template_file = pipeline.create_labeling_template(samples) # Step 7: Prepare for Mistral inference requests_file = pipeline.save_for_mistral_inference(semantic_filtered) print("\n" + "=" * 80) print("PIPELINE COMPLETE") print("=" * 80) print(f"\nReduced from {len(df):,} messages to {len(semantic_filtered):,} chunks") print(f"Total reduction: {(1 - len(semantic_filtered)*20/len(df))*100:.1f}%") print(f"\nEstimated cost for Mistral inference:") print(f" Mixtral 8x22B: {len(semantic_filtered) * 0.5 / 60 * 1.45:.2f} (4-8 hours)") print(f" Mistral 7B: {len(semantic_filtered) * 0.3 / 60 * 0.49:.2f} (2-4 hours)") print(f" Total: ${(len(semantic_filtered) * 0.5 / 60 * 1.45) + (len(semantic_filtered) * 0.3 / 60 * 0.49):.2f}")