#!/usr/bin/env python3
"""
Ethical Open-Source Legal Discovery Pipeline
Uses Mistral models (French company, no Trump connections)
Integrates: dual-model semantic filtering + random sampling + Mistral inference
"""
import json
import random
import re
from pathlib import Path
from typing import List, Dict

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


class EthicalDiscoveryPipeline:
    """
    Complete ethical discovery pipeline using only open-source models
    from companies with no Trump connections.
    """

    def __init__(self, csv_path: str, output_dir: str = './ethical_discovery_output'):
        self.csv_path = csv_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)

        # Dual embedding models (loaded lazily in dual_semantic_filter)
        self.embedding_model1 = None
        self.embedding_model2 = None

        # Jennifer Capasso v. MSK criteria
        self.criteria = {
            'plaintiff_name': 'Jennifer Capasso',
            'plaintiff_variations': [
                'jennifer capasso', 'jen capasso', 'jennifer', 'jen',
                'capasso', 'j capasso', 'jc', 'jenny'
            ],
            'facility_names': [
                'memorial sloan kettering', 'msk', 'sloan kettering',
                'memorial sloan', 'sloan', 'kettering'
            ],
            'key_topics': [
                # Treatment at MSK
                'treatment', 'medical care', 'doctor', 'physician', 'nurse',
                'appointment', 'visit', 'hospital', 'clinic', 'surgery',
                'procedure', 'diagnosis', 'medication', 'prescription',

                # Complaints
                'complaint', 'complain', 'complained', 'issue', 'problem',
                'concern', 'patient representative', 'patient advocate',

                # Patient information updates
                'patient information', 'medical records', 'pronouns',
                'gender identity', 'gender marker', 'update records',

                # Discrimination
                'discrimination', 'discriminate', 'discriminated',
                'bias', 'unfair', 'mistreat', 'transphobia', 'misgendered',
                'deadname', 'wrong pronouns', 'refused', 'denied',

                # March 7, 2022 surgery
                'march 7', 'march 2022', '3/7/22', '3/7/2022', 'surgery',

                # Emotional distress
                'emotional distress', 'mental anguish', 'pain', 'suffering',
                'trauma', 'anxious', 'depressed', 'stress'
            ]
        }

    def load_and_preprocess(self) -> pd.DataFrame:
        """Load Signal chat CSV and preprocess"""
        print(f"\nLoading Signal chat CSV: {self.csv_path}")

        df = pd.read_csv(self.csv_path)
        df.columns = df.columns.str.lower().str.strip()

        # Add 1-based line numbers for citation in attorney output
        df['line_number'] = range(1, len(df) + 1)
        df['message'] = df['message'].fillna('')
        df['message_normalized'] = df['message'].apply(self.normalize_text)

        print(f"Loaded {len(df):,} messages")
        return df

    def normalize_text(self, text: str) -> str:
        """Normalize text with abbreviation expansion"""
        if pd.isna(text) or text == '':
            return ""

        text = str(text).lower()

        expansions = {
            'msk': 'memorial sloan kettering',
            'dr': 'doctor',
            'appt': 'appointment', 'hosp': 'hospital',
            'med': 'medical', 'rx': 'prescription',
            'pt': 'patient', 'pron': 'pronoun'
        }

        # Expand only whole words; a bare str.replace would corrupt words that
        # merely contain an abbreviation (e.g. 'pronouns' -> 'pronounouns',
        # 'except' -> 'excepatient')
        for abbr, full in expansions.items():
            text = re.sub(rf'\b{re.escape(abbr)}\b\.?', full, text)

        return text

    def create_chunks(self, df: pd.DataFrame, chunk_size: int = 20,
                      overlap: int = 5) -> List[Dict]:
        """Create overlapping chunks so context is not lost at chunk boundaries"""
        print(f"\nCreating chunks (size={chunk_size}, overlap={overlap})...")

        chunks = []
        total = len(df)
        step = chunk_size - overlap

        for i in range(0, total, step):
            chunk_df = df.iloc[i:i + chunk_size]
            if len(chunk_df) == 0:
                break

            chunk = {
                'chunk_id': len(chunks),
                'start_line': int(chunk_df['line_number'].iloc[0]),
                'end_line': int(chunk_df['line_number'].iloc[-1]),
                'messages': chunk_df.to_dict('records'),
                'combined_text': ' '.join(chunk_df['message_normalized'].fillna('')),
                'timestamp_start': chunk_df['timestamp'].iloc[0],
                'timestamp_end': chunk_df['timestamp'].iloc[-1]
            }
            chunks.append(chunk)

        print(f"Created {len(chunks):,} chunks")
        return chunks

    def keyword_filter(self, chunks: List[Dict]) -> List[Dict]:
        """Keep only chunks containing at least one case keyword"""
        print("\nApplying keyword filter...")

        all_keywords = (
            self.criteria['plaintiff_variations'] +
            self.criteria['facility_names'] +
            self.criteria['key_topics']
        )

        filtered = []
        for chunk in chunks:
            text = chunk['combined_text']
            matches = [kw for kw in all_keywords if kw in text]

            if matches:
                chunk['keyword_matches'] = matches
                chunk['keyword_score'] = len(set(matches))
                filtered.append(chunk)

        reduction = (1 - len(filtered) / len(chunks)) * 100 if chunks else 0.0
        print(f"Filtered: {len(filtered):,} / {len(chunks):,} chunks ({reduction:.1f}% reduction)")

        return filtered

    def dual_semantic_filter(self, chunks: List[Dict],
                             threshold1: float = 0.25,
                             threshold2: float = 0.25,
                             merge_strategy: str = 'union') -> List[Dict]:
        """
        Semantic filtering with two embedding models.
        The default 'union' strategy keeps a chunk if either model passes,
        favoring recall over precision.
        """
        print("\nApplying dual-model semantic filter...")
        print(f"  Strategy: {merge_strategy}")

        # Load models lazily
        if self.embedding_model1 is None:
            print("  Loading Model 1: all-MiniLM-L6-v2...")
            self.embedding_model1 = SentenceTransformer('all-MiniLM-L6-v2')

        if self.embedding_model2 is None:
            print("  Loading Model 2: all-mpnet-base-v2...")
            self.embedding_model2 = SentenceTransformer('all-mpnet-base-v2')

        # Query texts from subpoena criteria
        queries = [
            "Jennifer Capasso treatment at Memorial Sloan Kettering Cancer Center MSK",
            "complaint to MSK staff about Jennifer Capasso patient care",
            "update patient pronouns gender identity markers at MSK hospital",
            "gender markers at other hospitals medical records",
            "discrimination based on gender identity transgender",
            "March 7 2022 surgery at MSK Memorial Sloan Kettering",
            "emotional distress mental anguish pain suffering from medical treatment"
        ]

        # Compute query embeddings
        print("  Computing query embeddings...")
        query_emb1 = self.embedding_model1.encode(queries)
        query_emb2 = self.embedding_model2.encode(queries)

        # Compute chunk embeddings
        print(f"  Computing embeddings for {len(chunks):,} chunks...")
        chunk_texts = [c['combined_text'] for c in chunks]

        chunk_emb1 = self.embedding_model1.encode(chunk_texts, show_progress_bar=True, batch_size=32)
        chunk_emb2 = self.embedding_model2.encode(chunk_texts, show_progress_bar=True, batch_size=32)

        # Compute similarities; each chunk is scored by its best-matching query
        print("  Computing semantic similarities...")
        similarities1 = cosine_similarity(chunk_emb1, query_emb1)
        similarities2 = cosine_similarity(chunk_emb2, query_emb2)

        max_sim1 = similarities1.max(axis=1)
        max_sim2 = similarities2.max(axis=1)

        # Apply merge strategy
        filtered = []
        for i, chunk in enumerate(chunks):
            score1 = float(max_sim1[i])
            score2 = float(max_sim2[i])

            if merge_strategy == 'union':
                passes = (score1 >= threshold1) or (score2 >= threshold2)
                combined_score = max(score1, score2)
            elif merge_strategy == 'intersection':
                passes = (score1 >= threshold1) and (score2 >= threshold2)
                combined_score = min(score1, score2)
            else:  # weighted
                combined_score = 0.4 * score1 + 0.6 * score2
                passes = combined_score >= (0.4 * threshold1 + 0.6 * threshold2)

            if passes:
                chunk['semantic_score_model1'] = score1
                chunk['semantic_score_model2'] = score2
                chunk['semantic_score_combined'] = combined_score
                filtered.append(chunk)

        print(f"  Model 1 alone: {(max_sim1 >= threshold1).sum()}")
        print(f"  Model 2 alone: {(max_sim2 >= threshold2).sum()}")
        print(f"  Combined: {len(filtered):,} chunks")
        print(f"  Total reduction: {(1 - len(filtered)/len(chunks))*100:.1f}%")

        return filtered

    def select_random_samples(self, chunks: List[Dict], n_samples: int = 20,
                              seed: int = 42) -> List[Dict]:
        """
        Randomly select samples for attorney labeling.
        Stratifies by semantic score quartile to ensure diversity.
        """
        print(f"\nSelecting {n_samples} random samples for attorney labeling...")

        random.seed(seed)

        # Stratify by score quartiles; the top stratum is open-ended so a
        # perfect score of 1.0 is not silently excluded
        scores = [c.get('semantic_score_combined', 0) for c in chunks]
        quartiles = np.percentile(scores, [25, 50, 75])

        samples = []
        strata = [(0, quartiles[0]), (quartiles[0], quartiles[1]),
                  (quartiles[1], quartiles[2]), (quartiles[2], float('inf'))]
        for q_low, q_high in strata:
            stratum = [c for c in chunks if q_low <= c.get('semantic_score_combined', 0) < q_high]
            if stratum:
                n_select = min(n_samples // 4, len(stratum))
                samples.extend(random.sample(stratum, n_select))

        # Fill remaining slots if the strata were too small
        if len(samples) < n_samples:
            selected_ids = {c['chunk_id'] for c in samples}
            remaining = [c for c in chunks if c['chunk_id'] not in selected_ids]
            samples.extend(random.sample(remaining, min(n_samples - len(samples), len(remaining))))

        random.shuffle(samples)
        samples = samples[:n_samples]

        print(f"Selected {len(samples)} samples across score ranges")
        return samples

    def create_labeling_template(self, samples: List[Dict],
                                 output_file: str = 'attorney_labeling_template.txt'):
        """Create attorney-friendly labeling template"""
        filepath = self.output_dir / output_file

        with open(filepath, 'w') as f:
            f.write("ATTORNEY LABELING TEMPLATE\n")
            f.write("Jennifer Capasso v. Memorial Sloan Kettering Cancer Center\n")
            f.write("=" * 80 + "\n\n")

            f.write("INSTRUCTIONS:\n")
            f.write("For each message below, please provide:\n")
            f.write("1. RESPONSIVE: YES or NO\n")
            f.write("2. REASONING: Brief explanation of your decision\n")
            f.write("3. CRITERIA: Which subpoena criteria matched (1-7):\n")
            f.write("   1. Treatment at MSK\n")
            f.write("   2. Complaints to MSK staff\n")
            f.write("   3. Pronoun/gender marker update requests\n")
            f.write("   4. Gender markers at other hospitals\n")
            f.write("   5. Prior discrimination (any setting)\n")
            f.write("   6. March 7, 2022 surgery\n")
            f.write("   7. Emotional distress/economic loss\n\n")
            f.write("=" * 80 + "\n\n")

            for i, sample in enumerate(samples, 1):
                # Use the chunk's first message as the one to label
                first_msg = sample['messages'][0] if sample['messages'] else {}

                f.write(f"SAMPLE {i}\n")
                f.write("-" * 80 + "\n")
                f.write(f"Line: {first_msg.get('line_number', 'N/A')}\n")
                f.write(f"Time: {first_msg.get('timestamp', 'N/A')}\n")
                f.write(f"Sender: {first_msg.get('sender', 'N/A')}\n")
                f.write(f"Message: {first_msg.get('message', 'N/A')}\n\n")

                # Show context: the first five messages of the chunk, with the
                # labeled message marked by ">>>"
                f.write("Context (surrounding messages):\n")
                for j, msg in enumerate(sample['messages'][:5], 1):
                    marker = ">>>" if j == 1 else "   "
                    f.write(f"{marker} [{msg.get('sender', '?')}]: {msg.get('message', '')[:80]}...\n")
                f.write("\n")

                f.write("RESPONSIVE: _______\n")
                f.write("REASONING: _____________________________________________\n")
                f.write("CRITERIA: _______\n")
                f.write("\n" + "=" * 80 + "\n\n")

        print(f"\nLabeling template saved: {filepath}")
        print("Please have attorney complete this template and save as:")
        print(f"  {self.output_dir / 'attorney_labels_completed.txt'}")

        return filepath
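
    # The few-shot step in save_for_mistral_inference below needs a parser for
    # the completed template. This is a minimal sketch of one, assuming the
    # attorney fills in the RESPONSIVE/REASONING/CRITERIA blanks in place and
    # leaves the template layout unchanged; a production parser would need to
    # be more forgiving of formatting drift.
    def parse_attorney_labels(self, labels_file: str) -> List[Dict]:
        """Parse a completed attorney labeling template into labeled examples (sketch)."""
        text = Path(labels_file).read_text()
        examples = []
        # Each sample block starts with "SAMPLE <n>"; element [0] is the header
        for block in re.split(r'\nSAMPLE \d+\n', text)[1:]:
            message = re.search(r'Message: (.*)', block)
            responsive = re.search(r'RESPONSIVE:\s*(YES|NO)', block, re.IGNORECASE)
            reasoning = re.search(r'REASONING:\s*(.*)', block)
            criteria = re.search(r'CRITERIA:\s*([\d,\s]+)', block)
            if message and responsive:
                examples.append({
                    'message': message.group(1).strip(),
                    'responsive': responsive.group(1).upper() == 'YES',
                    'reasoning': (reasoning.group(1).strip('_ ') if reasoning else ''),
                    'criteria': (criteria.group(1).strip() if criteria else '')
                })
        return examples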

    def save_for_mistral_inference(self, chunks: List[Dict],
                                   few_shot_file: str = None):
        """
        Save chunks in a format ready for Mistral model inference.
        Optionally includes few-shot examples from attorney labels.
        """
        print("\nPreparing data for Mistral inference...")

        # Load few-shot examples if provided (see the parse_attorney_labels
        # sketch above for one way to extract them from the completed template)
        few_shot_prompt = ""
        if few_shot_file and Path(few_shot_file).exists():
            print(f"  Loading few-shot examples from: {few_shot_file}")
            few_shot_prompt = "\n\nHere are examples of how to classify messages:\n"
            few_shot_prompt += "[Attorney-labeled examples would be inserted here]\n"

        # Create inference requests
        inference_requests = []

        system_prompt = """You are a legal document review specialist analyzing Signal chat messages for a discrimination lawsuit.
CASE: Jennifer Capasso v. Memorial Sloan Kettering Cancer Center (MSK)
CLAIM: Discrimination based on gender identity
SUBPOENA CRITERIA - Messages are responsive if they relate to:
1. Jennifer Capasso's treatment at MSK
2. Complaints to MSK staff about Jennifer Capasso
3. Requests to update Jennifer Capasso's pronouns/gender markers at MSK
4. Gender markers for Jennifer Capasso at other hospitals
5. Prior discrimination Jennifer Capasso experienced (any setting)
6. Jennifer Capasso's March 7, 2022 surgery at MSK
7. Emotional distress/economic loss from MSK treatment
IMPORTANT: Err on the side of OVER-INCLUSION (high recall)."""

        for chunk in chunks:
            messages_text = ""
            for msg in chunk['messages']:
                messages_text += f"Line {msg['line_number']} [{msg['sender']}]: {msg['message']}\n"

            prompt = f"""{system_prompt}
{few_shot_prompt}
MESSAGES TO REVIEW (Lines {chunk['start_line']}-{chunk['end_line']}):
{messages_text}
Respond with JSON:
{{
  "responsive_line_numbers": [list of responsive line numbers],
  "reasoning": "brief explanation",
  "confidence": "high/medium/low"
}}"""

            inference_requests.append({
                'chunk_id': chunk['chunk_id'],
                'prompt': prompt,
                'chunk_data': chunk
            })

        # Save requests; default=str keeps pandas Timestamps in chunk_data
        # from breaking JSON serialization
        requests_file = self.output_dir / 'mistral_inference_requests.jsonl'
        with open(requests_file, 'w') as f:
            for req in inference_requests:
                f.write(json.dumps(req, default=str) + '\n')

        print(f"Saved {len(inference_requests):,} inference requests to: {requests_file}")
        print("\nNext steps:")
        print("1. Deploy Mixtral 8x22B on Vast.ai (H100 @ $1.33-1.56/hr)")
        print("2. Deploy Mistral 7B on Vast.ai (RTX 4090 @ $0.34/hr)")
        print("3. Run inference on both models")
        print("4. Merge results (take union for high recall)")

        return requests_file
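

# Hedged sketch of steps 1-3 above: send the saved prompts to a self-hosted
# Mistral model. It assumes a vLLM (or other OpenAI-compatible) server is
# already running and serving the model; the URL, model name, and generation
# parameters below are placeholders, not part of the pipeline itself.
def run_mistral_inference(requests_file: str, output_file: str,
                          url: str = 'http://localhost:8000/v1/chat/completions',
                          model: str = 'mistralai/Mixtral-8x22B-Instruct-v0.1'):
    """Run each saved inference request against an OpenAI-compatible endpoint (sketch)."""
    import requests  # third-party HTTP client, assumed installed

    with open(requests_file) as fin, open(output_file, 'w') as fout:
        for line in fin:
            req = json.loads(line)
            resp = requests.post(url, json={
                'model': model,
                'messages': [{'role': 'user', 'content': req['prompt']}],
                'temperature': 0.0,  # deterministic output for document review
            }, timeout=300)
            resp.raise_for_status()
            answer = resp.json()['choices'][0]['message']['content']
            fout.write(json.dumps({'chunk_id': req['chunk_id'], 'response': answer}) + '\n')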


# Example usage
if __name__ == "__main__":
    # Initialize pipeline
    pipeline = EthicalDiscoveryPipeline('signal_messages.csv')

    # Run complete pipeline
    print("\nETHICAL DISCOVERY PIPELINE")
    print("Using only Mistral models (French company, no Trump connections)")
    print("=" * 80)

    # Step 1: Load and preprocess
    df = pipeline.load_and_preprocess()

    # Step 2: Create chunks
    chunks = pipeline.create_chunks(df, chunk_size=20, overlap=5)

    # Step 3: Keyword filter
    keyword_filtered = pipeline.keyword_filter(chunks)

    # Step 4: Dual-model semantic filter
    semantic_filtered = pipeline.dual_semantic_filter(
        keyword_filtered,
        threshold1=0.25,
        threshold2=0.25,
        merge_strategy='union'  # High recall
    )

    # Step 5: Select random samples for attorney
    samples = pipeline.select_random_samples(semantic_filtered, n_samples=20)

    # Step 6: Create labeling template
    template_file = pipeline.create_labeling_template(samples)

    # Step 7: Prepare for Mistral inference
    requests_file = pipeline.save_for_mistral_inference(semantic_filtered)

    print("\n" + "=" * 80)
    print("PIPELINE COMPLETE")
    print("=" * 80)
    print(f"\nReduced from {len(df):,} messages to {len(semantic_filtered):,} chunks")
    # Approximate: each surviving chunk covers ~20 messages (with overlap)
    print(f"Total reduction: {(1 - len(semantic_filtered)*20/len(df))*100:.1f}%")
    print("\nEstimated cost for Mistral inference:")
    print(f"  Mixtral 8x22B: ${len(semantic_filtered) * 0.5 / 60 * 1.45:.2f} (4-8 hours)")
    print(f"  Mistral 7B: ${len(semantic_filtered) * 0.3 / 60 * 0.49:.2f} (2-4 hours)")
    print(f"  Total: ${(len(semantic_filtered) * 0.5 / 60 * 1.45) + (len(semantic_filtered) * 0.3 / 60 * 0.49):.2f}")