diff --git a/.claude/skills/outerbounds/SKILL.md b/.claude/skills/outerbounds/SKILL.md index 6ec5cb1..d8c764b 100644 --- a/.claude/skills/outerbounds/SKILL.md +++ b/.claude/skills/outerbounds/SKILL.md @@ -12,11 +12,12 @@ and a description how it works in [starter-project.md](starter-project.md). You must - Include batch, offline workflows under `flows/`, structured as Metaflow flows. - - Preferably include a `@highlight` card in each flow (see `HighlightTester` in the starter project for example) + - Preferably include a `@highlight` card in each flow in the `end` step + (see `HighlightTester` in the starter project for example) - Include online componets under `deployments/` with a proper configuration. - Define data assets under `data/` and model assets under `models/` - Read [project-assets.md](project-assets.md) for instructions how to define assets - - Include a `@card` for steps that consume and produce assets + - Include a `@card` for steps that consume and produce assets (it must be the only `@card` in the step) - Include a descriptive README.md at the top level, for each `deployment`, `flow`, and asset. ## Defining flows @@ -43,7 +44,8 @@ For instance, python flow/a/flow.py run ``` -Or, if the flow has external dependencies defined with `@pypi` or `@conda`, leverage Fast Bakery on Kubernetes: +Or, if the flow has external dependencies defined with `@pypi`, `@pypi_base`, `@conda`, or `@conda_base`, +leverage Fast Bakery on Kubernetes: ``` python flow/a/flow.py --environment=fast-bakery run --with kubernetes diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2f6d976 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +.metaflow_spin/ +__pycache__/ +*.py[cod] +*.egg-info/ +dist/ +build/ +.eggs/ +*.egg +.venv/ +venv/ +.env +*.so +.mypy_cache/ +.pytest_cache/ +.ruff_cache/ diff --git a/CLAUDE.md b/CLAUDE.md index 40bca50..5fd7efc 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -10,7 +10,8 @@ The project that has the following components: 1. An ETL flow that fetches data from Snowflake - Refer to `example_data.py` for a sample - - Data is processed in batches of at most 1000 rows + - Only include rows with a valid website url + - Data is processed in batches of at most 100 rows - Store the IDs of rows that were processed, next time the flow executes, fetch the next batch - Store the state of processing in an artifact, use Metaflow client to retrieve the state - Include an option for resetting state diff --git a/README.md b/README.md new file mode 100644 index 0000000..5360eb0 --- /dev/null +++ b/README.md @@ -0,0 +1,57 @@ +# Agentic Code Example + +An Outerbounds project that continuously fetches company data from Snowflake, +enriches it with LLM-generated tags by analyzing company websites, and provides +an interactive UI for exploration. + +## Architecture + +``` +SnowflakeETL (hourly) + │ + ├── Fetches batch of 100 companies from Snowflake + ├── Tracks processed IDs across runs + ├── Registers "companies" data asset + └── Publishes "enrich_companies" event + │ + ▼ +CompanyEnricher (event-triggered) + │ + ├── Scrapes company websites in parallel (10 tasks) + ├── Generates 5 tags per company using local LLM + ├── Merges with previous enrichment results + └── Registers "enriched-companies" data asset + │ + ▼ +Company Explorer (deployed UI) + │ + └── Streamlit app to browse companies and tags +``` + +## Components + +| Component | Location | Description | +|-----------|----------|-------------| +| Snowflake ETL | `flows/snowflake-etl/` | Hourly batch ingestion from Snowflake | +| Company Enricher | `flows/company-enricher/` | Website scraping + LLM tagging | +| Company Explorer | `deployments/company-explorer/` | Interactive Streamlit UI | +| Shared Utils | `src/company_utils/` | Snowflake queries, web scraping | + +## Assets + +- **companies** (`data/companies/`) - Raw company data from Snowflake +- **enriched-companies** (`data/enriched-companies/`) - Companies with LLM tags +- **tag-generator** (`models/tag-generator/`) - SmolLM2-1.7B-Instruct model + +## Local Development + +```bash +# Run ETL flow +python flows/snowflake-etl/flow.py run + +# Run enricher (needs GPU or patience on CPU) +python flows/company-enricher/flow.py --environment=fast-bakery run --with kubernetes + +# Reset ETL state +python flows/snowflake-etl/flow.py run --reset yes +``` diff --git a/data/companies/asset_config.toml b/data/companies/asset_config.toml new file mode 100644 index 0000000..5fda198 --- /dev/null +++ b/data/companies/asset_config.toml @@ -0,0 +1,7 @@ +name = "Company Dataset" +id = "companies" +description = "Raw company data fetched from Snowflake in batches" + +[properties] +source = "Snowflake free_company_dataset" +batch_size = "100" diff --git a/data/enriched-companies/asset_config.toml b/data/enriched-companies/asset_config.toml new file mode 100644 index 0000000..6793df7 --- /dev/null +++ b/data/enriched-companies/asset_config.toml @@ -0,0 +1,7 @@ +name = "Enriched Companies" +id = "enriched-companies" +description = "Companies enriched with LLM-generated tags from website analysis" + +[properties] +enrichment = "5 descriptive tags per company from local LLM" +source = "Company websites + LLM inference" diff --git a/deployments/company-explorer/README.md b/deployments/company-explorer/README.md new file mode 100644 index 0000000..25ee9b7 --- /dev/null +++ b/deployments/company-explorer/README.md @@ -0,0 +1,7 @@ +# Company Explorer + +A Streamlit app for browsing companies and their LLM-generated tags. + +- Filter companies by tags or search by name +- View tag distribution across the dataset +- See success/failure status of enrichment diff --git a/deployments/company-explorer/app.py b/deployments/company-explorer/app.py new file mode 100644 index 0000000..5accd4a --- /dev/null +++ b/deployments/company-explorer/app.py @@ -0,0 +1,131 @@ +import streamlit as st +from metaflow import Flow, namespace + +st.set_page_config(page_title="Company Explorer", layout="wide") + + +@st.cache_data(ttl=60) +def load_enriched_companies(): + """Load the latest enriched companies from the CompanyEnricher flow.""" + try: + namespace(None) + run = Flow("CompanyEnricher").latest_successful_run + return run.data.enriched_companies + except Exception as e: + st.error(f"Could not load enriched company data: {e}") + return [] + + +def parse_tag(tag): + """Parse a tag string that may contain multiple numbered tags into individual tags.""" + import re + tag = tag.strip() + # Check if this is a numbered list crammed into one string + numbered = re.split(r"\d+[\.\)]\s*", tag) + numbered = [t.strip().rstrip(",").strip() for t in numbered if t.strip()] + if len(numbered) >= 2: + return numbered + return [tag] if tag else [] + + +def get_all_tags(companies): + """Extract all unique tags across companies.""" + tags = set() + for c in companies: + for t in c.get("tags", []): + for parsed in parse_tag(t): + tags.add(parsed) + return sorted(tags) + + +def main(): + st.title("Company Explorer") + st.markdown("Browse companies and their LLM-generated tags.") + + companies = load_enriched_companies() + + if not companies: + st.warning("No enriched company data available yet. Run the SnowflakeETL and CompanyEnricher flows first.") + return + + # Sidebar filters + all_tags = get_all_tags(companies) + tagged_companies = [c for c in companies if c.get("status") == "success"] + failed_companies = [c for c in companies if c.get("status") != "success"] + + st.sidebar.header("Filters") + selected_tags = st.sidebar.multiselect("Filter by tags", all_tags) + show_failed = st.sidebar.checkbox("Show failed companies", value=False) + search = st.sidebar.text_input("Search by name") + + # Stats + col1, col2, col3 = st.columns(3) + col1.metric("Total Companies", len(companies)) + col2.metric("Successfully Tagged", len(tagged_companies)) + col3.metric("Unique Tags", len(all_tags)) + + st.markdown("---") + + # Filter companies + display = tagged_companies if not show_failed else companies + if selected_tags: + display = [ + c for c in display + if any( + p in selected_tags + for t in c.get("tags", []) + for p in parse_tag(t) + ) + ] + if search: + display = [ + c for c in display if search.lower() in c.get("name", "").lower() + ] + + st.subheader(f"Showing {len(display)} companies") + + # Display as cards in a grid + for i in range(0, len(display), 3): + cols = st.columns(3) + for j, col in enumerate(cols): + idx = i + j + if idx >= len(display): + break + company = display[idx] + with col: + with st.container(border=True): + st.markdown(f"### {company.get('name', 'Unknown')}") + domain = company.get("domain", "") + if domain: + st.markdown(f"[{domain}](https://{domain})") + if company.get("status") == "success": + tags = [ + p for t in company.get("tags", []) for p in parse_tag(t) + ] + tag_html = " ".join( + f'{t}' + for t in tags + ) + st.markdown(tag_html, unsafe_allow_html=True) + else: + st.error(f"Status: {company.get('status', 'unknown')}") + + # Tag cloud + if all_tags: + st.markdown("---") + st.subheader("All Tags") + tag_counts = {} + for c in tagged_companies: + for t in c.get("tags", []): + for p in parse_tag(t): + tag_counts[p] = tag_counts.get(p, 0) + 1 + sorted_tags = sorted(tag_counts.items(), key=lambda x: -x[1]) + tag_html = " ".join( + f'{tag} ({count})' + for tag, count in sorted_tags + ) + st.markdown(tag_html, unsafe_allow_html=True) + + +if __name__ == "__main__": + main() diff --git a/deployments/company-explorer/config.yml b/deployments/company-explorer/config.yml new file mode 100644 index 0000000..a8aa5e4 --- /dev/null +++ b/deployments/company-explorer/config.yml @@ -0,0 +1,15 @@ +name: company-explorer +port: 8000 +description: Interactive UI for exploring companies and their LLM-generated tags + +replicas: + min: 1 + max: 1 + +dependencies: + pypi: + streamlit: "" + outerbounds: "" + +commands: + - streamlit run deployments/company-explorer/app.py --server.port 8000 diff --git a/flows/company-enricher/README.md b/flows/company-enricher/README.md new file mode 100644 index 0000000..4e3f01f --- /dev/null +++ b/flows/company-enricher/README.md @@ -0,0 +1,10 @@ +# Company Enricher + +Enriches company data by scraping each company's website and using a local LLM +to generate 5 descriptive tags. + +- **Trigger**: Automatically triggered when SnowflakeETL finishes (`@trigger_on_finish`) +- **Parallelism**: Processes companies in parallel using foreach (up to 10 tasks) +- **LLM**: Uses SmolLM2-1.7B-Instruct to generate tags from website content +- **Output**: Merges results with previous runs and registers `enriched-companies` data asset +- **Cards**: Each parallel task shows real-time progress; join step shows summary with sample tags diff --git a/flows/company-enricher/flow.py b/flows/company-enricher/flow.py new file mode 100644 index 0000000..c310ac7 --- /dev/null +++ b/flows/company-enricher/flow.py @@ -0,0 +1,316 @@ +from metaflow import ( + card, + step, + current, + resources, + pypi, + retry, +) +from metaflow.cards import Markdown as MD, Table +from metaflow import trigger_on_finish +from obproject import ProjectFlow, highlight + +from company_utils import fetch_landing_page + +MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct" + +SYSTEM_PROMPT = ( + "You are a company analyst. Given a company's website content, " + "produce exactly 5 short descriptive tags that capture the company's " + "industry, products, and characteristics. " + "Return ONLY the tags as a comma-separated list, nothing else. " + "Example: enterprise software, cloud computing, B2B, cybersecurity, AI-powered" +) + + +def load_model(): + """Load the LLM and tokenizer once.""" + import torch + from transformers import AutoTokenizer, AutoModelForCausalLM + + device = "cuda" if torch.cuda.is_available() else "cpu" + tokenizer = AutoTokenizer.from_pretrained(MODEL) + model = AutoModelForCausalLM.from_pretrained( + MODEL, torch_dtype=torch.bfloat16 + ).to(device) + return tokenizer, model, device + + +def generate_tags(html_text, tokenizer, model, device): + """Use a local LLM to generate 5 descriptive tags from website content.""" + import torch + + # Truncate HTML to fit context - take first ~2000 chars of visible text + from html.parser import HTMLParser + + class TextExtractor(HTMLParser): + def __init__(self): + super().__init__() + self.parts = [] + self._skip = False + + def handle_starttag(self, tag, attrs): + if tag in ("script", "style", "noscript"): + self._skip = True + + def handle_endtag(self, tag): + if tag in ("script", "style", "noscript"): + self._skip = False + + def handle_data(self, data): + if not self._skip: + text = data.strip() + if text: + self.parts.append(text) + + extractor = TextExtractor() + extractor.feed(html_text) + visible_text = " ".join(extractor.parts)[:2000] + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": "user", + "content": f"Here is the website content:\n\n{visible_text}\n\nProvide 5 descriptive tags:", + }, + ] + + input_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + inputs = tokenizer(input_text, return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, max_new_tokens=100, temperature=0.3, do_sample=True + ) + + response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True) + tags = parse_tags(response) + return tags + + +def parse_tags(response): + """Parse tags from LLM response, handling numbered lists and comma-separated formats.""" + import re + + text = response.strip() + # Try numbered list format: "1. Tag 2. Tag ..." or "1) Tag 2) Tag ..." + numbered = re.split(r"\d+[\.\)]\s*", text) + numbered = [t.strip().rstrip(",").strip() for t in numbered if t.strip()] + if len(numbered) >= 2: + return numbered[:5] + # Try dash/bullet list + lines = [l.strip().lstrip("-•*").strip() for l in text.splitlines() if l.strip()] + if len(lines) >= 2: + return lines[:5] + # Fallback: comma-separated + tags = [t.strip() for t in text.split(",") if t.strip()] + return tags[:5] + + +@trigger_on_finish(flow="SnowflakeETL") +class CompanyEnricher(ProjectFlow): + + @card(type="blank") + @step + def start(self): + # Get the latest batch from the ETL flow via asset + self.companies_batch = self.prj.get_data("companies") + print(f"Received {len(self.companies_batch)} companies to enrich") + + # Split into chunks for parallel processing (10 tasks) + batch = self.companies_batch + n = max(1, len(batch) // 10) + self.chunks = [batch[i : i + n] for i in range(0, len(batch), n)] + # Cap at 10 chunks + if len(self.chunks) > 10: + self.chunks[-2].extend(self.chunks[-1]) + self.chunks = self.chunks[:10] + + self.next(self.enrich, foreach="chunks") + + @resources(cpu=4, memory=16000) + @retry(times=1) + @card(type="blank", id="enrich_progress", refresh_interval=5) + @pypi( + python="3.11.11", + packages={ + "transformers": "4.55.2", + "torch": "2.8.0", + "requests": "2.32.3", + "pyopenssl": "24.2.1", + "cryptography": "43.0.3", + }, + ) + @step + def enrich(self): + chunk = self.input + results = [] + total = len(chunk) + + current.card["enrich_progress"].append( + MD(f"## Processing {total} companies...") + ) + current.card["enrich_progress"].refresh() + + # Consume the model asset for lineage tracking (best-effort) + try: + self.prj.asset.consume_model_asset("tag-generator") + except Exception: + print("Model asset lineage tracking unavailable") + + # Load the model once for all companies in this chunk + tokenizer, model, device = load_model() + + for i, company in enumerate(chunk): + company_name = company.get("NAME", company.get("name", "Unknown")) + website = company.get("WEBSITE", company.get("website", "")) + company_id = str(company.get("ID", company.get("id", ""))) + + html = fetch_landing_page(website) + if html: + try: + tags = generate_tags(html, tokenizer, model, device) + results.append( + { + "id": company_id, + "name": company_name, + "domain": website, + "tags": tags, + "status": "success", + } + ) + current.card["enrich_progress"].append( + MD(f"✅ [{i+1}/{total}] **{company_name}** — {', '.join(tags)}") + ) + print(f" [{i+1}/{total}] {company_name}: {tags}") + except Exception as e: + results.append( + { + "id": company_id, + "name": company_name, + "domain": website, + "tags": [], + "status": f"llm_error: {e}", + } + ) + current.card["enrich_progress"].append( + MD(f"❌ [{i+1}/{total}] **{company_name}** — LLM error: {e}") + ) + print(f" [{i+1}/{total}] {company_name}: LLM error - {e}") + else: + results.append( + { + "id": company_id, + "name": company_name, + "domain": website, + "tags": [], + "status": "fetch_failed", + } + ) + current.card["enrich_progress"].append( + MD(f"❌ [{i+1}/{total}] **{company_name}** — website fetch failed") + ) + print(f" [{i+1}/{total}] {company_name}: website fetch failed") + + current.card["enrich_progress"].refresh() + + self.chunk_results = results + + # Update card with summary + success = sum(1 for r in results if r["status"] == "success") + failed = total - success + current.card["enrich_progress"].append( + MD(f"\n---\n### Done: ✅ {success} succeeded, ❌ {failed} failed out of {total}") + ) + current.card["enrich_progress"].refresh() + + self.next(self.join) + + @card(type="blank") + @step + def join(self, inputs): + # Merge results from all parallel tasks + all_results = [] + for inp in inputs: + all_results.extend(inp.chunk_results) + + # Fetch previous enrichment results and merge + try: + prev = self.prj.get_data("enriched-companies") + # Build a dict keyed by company id, new results overwrite old + merged = {r["id"]: r for r in prev} + except Exception: + merged = {} + + for r in all_results: + merged[r["id"]] = r + + self.enriched_companies = list(merged.values()) + + # Register the enriched data as an asset + success_count = sum(1 for r in self.enriched_companies if r["status"] == "success") + self.prj.register_data( + "enriched-companies", + "enriched_companies", + annotations={ + "total_companies": str(len(self.enriched_companies)), + "successfully_tagged": str(success_count), + "batch_new": str(len(all_results)), + }, + ) + + # Summary card + failed = [r for r in all_results if r["status"] != "success"] + current.card.append(MD("## Company Enrichment Summary")) + current.card.append( + MD( + f"- **New companies processed**: {len(all_results)}\n" + f"- **Successfully tagged**: {sum(1 for r in all_results if r['status'] == 'success')}\n" + f"- **Failed**: {len(failed)}\n" + f"- **Total enriched companies**: {len(self.enriched_companies)}" + ) + ) + + # Show sample tags + tagged = [r for r in all_results if r["status"] == "success"][:10] + if tagged: + rows = [[r["name"], r["domain"], ", ".join(r["tags"])] for r in tagged] + current.card.append(MD("### Sample Tags")) + current.card.append( + Table( + headers=["Company", "Domain", "Tags"], + data=rows, + ) + ) + + if failed: + current.card.append(MD("### Failed Companies")) + fail_rows = [[r["name"], r["domain"], r["status"]] for r in failed[:10]] + current.card.append( + Table( + headers=["Company", "Domain", "Error"], + data=fail_rows, + ) + ) + + self.new_count = len(all_results) + self.success_count = success_count + + self.next(self.end) + + @highlight + @step + def end(self): + self.highlight.title = "Company Enricher" + self.highlight.add_column( + big=str(self.new_count), small="processed" + ) + self.highlight.add_column( + big=str(self.success_count), small="total tagged" + ) + + +if __name__ == "__main__": + CompanyEnricher() diff --git a/flows/snowflake-etl/README.md b/flows/snowflake-etl/README.md new file mode 100644 index 0000000..d99904c --- /dev/null +++ b/flows/snowflake-etl/README.md @@ -0,0 +1,11 @@ +# Snowflake ETL + +Fetches company data from Snowflake in batches of 100 rows. Tracks which rows +have been processed across runs using Metaflow artifacts, so each execution +picks up where the last one left off. + +- **Schedule**: Hourly +- **State**: Processed row IDs stored as an artifact, retrieved via Metaflow client +- **Reset**: Pass `--reset yes` to start from scratch +- **Output**: Triggers CompanyEnricher flow on completion via `@trigger_on_finish` +- **Asset**: Registers fetched batch as `companies` data asset diff --git a/flows/snowflake-etl/flow.py b/flows/snowflake-etl/flow.py new file mode 100644 index 0000000..ed4715d --- /dev/null +++ b/flows/snowflake-etl/flow.py @@ -0,0 +1,91 @@ +from metaflow import ( + card, + step, + current, + Flow, + schedule, + Parameter, + pypi_base, +) +from metaflow.cards import Markdown as MD +from obproject import ProjectFlow, highlight + +from company_utils import fetch_company_batch + + +@pypi_base(python="3.11.11", packages={ + "snowflake-connector-python[pandas]": "3.12.3", + "pyopenssl": "24.2.1", + "cryptography": "43.0.3", +}) +@schedule(hourly=True) +class SnowflakeETL(ProjectFlow): + + reset = Parameter( + "reset", + default="no", + help="Set to 'yes' to reset processing state and start from scratch", + ) + + @card(type="blank") + @step + def start(self): + # Retrieve previously processed IDs from the last successful run + processed_ids = set() + if self.reset == "no": + try: + run = Flow(current.flow_name).latest_successful_run + processed_ids = set(run.data.processed_ids) + print(f"Resuming from previous state: {len(processed_ids)} rows already processed") + except Exception: + print("No previous state found - starting from scratch") + else: + print("Reset requested - starting from scratch") + + # Fetch next batch from Snowflake + df = fetch_company_batch(processed_ids if processed_ids else None) + batch_size = len(df) + print(f"Fetched batch of {batch_size} rows from Snowflake") + + if batch_size == 0: + current.card.append(MD("## No new rows to process")) + self.batch = [] + self.processed_ids = list(processed_ids) + self.next(self.end) + return + + # Convert batch to list of dicts for downstream processing + self.batch = df.to_dict(orient="records") + new_ids = set(str(row.get("ID", row.get("id", ""))) for row in self.batch) + self.processed_ids = list(processed_ids | new_ids) + + # Register the company data as an asset + self.companies = self.batch + self.prj.register_data( + "companies", + "companies", + annotations={"batch_size": str(batch_size), "total_processed": str(len(self.processed_ids))}, + ) + + # Update card + current.card.append(MD(f"## Snowflake ETL Batch")) + current.card.append(MD(f"- **New rows fetched**: {batch_size}")) + current.card.append(MD(f"- **Total rows processed**: {len(self.processed_ids)}")) + if self.batch: + sample = self.batch[0] + cols = list(sample.keys())[:5] + current.card.append(MD(f"- **Sample columns**: {', '.join(cols)}")) + + self.next(self.end) + + @highlight + @step + def end(self): + batch_size = len(self.batch) + self.highlight.title = "Snowflake ETL" + self.highlight.add_column(big=str(batch_size), small="new rows") + self.highlight.add_column(big=str(len(self.processed_ids)), small="total processed") + + +if __name__ == "__main__": + SnowflakeETL() diff --git a/models/tag-generator/asset_config.toml b/models/tag-generator/asset_config.toml new file mode 100644 index 0000000..1b74c32 --- /dev/null +++ b/models/tag-generator/asset_config.toml @@ -0,0 +1,9 @@ +name = "Tag Generator LLM" +id = "tag-generator" +description = "Small local LLM used to generate descriptive tags for companies" +blobs = ["HuggingFaceTB/SmolLM2-1.7B-Instruct"] + +[properties] +provider = "HuggingFaceTB/SmolLM2-1.7B-Instruct" +task = "text-generation" +output = "5 descriptive tags per company" diff --git a/obproject.toml b/obproject.toml index 0641949..f644ac9 100644 --- a/obproject.toml +++ b/obproject.toml @@ -2,6 +2,3 @@ platform = 'dev-yellow.outerbounds.xyz' project = 'ob_agentic_code_example' title = 'Agentic Code Example' - -[dev-assets] -branch = 'dev' diff --git a/src/company_utils/__init__.py b/src/company_utils/__init__.py new file mode 100644 index 0000000..70fdfff --- /dev/null +++ b/src/company_utils/__init__.py @@ -0,0 +1,4 @@ +from .snowflake import fetch_company_batch +from .scraper import fetch_landing_page + +METAFLOW_PACKAGE_POLICY = "include" diff --git a/src/company_utils/scraper.py b/src/company_utils/scraper.py new file mode 100644 index 0000000..931cbc7 --- /dev/null +++ b/src/company_utils/scraper.py @@ -0,0 +1,24 @@ +import requests + + +def fetch_landing_page(url, timeout=10): + """ + Fetch the landing page HTML of a company website. + Returns the text content or None on failure. + """ + if not url: + return None + if not url.startswith("http"): + url = "https://" + url + try: + resp = requests.get( + url, + timeout=timeout, + headers={"User-Agent": "Mozilla/5.0 (compatible; CompanyEnricher/1.0)"}, + allow_redirects=True, + ) + resp.raise_for_status() + return resp.text + except Exception as e: + print(f"Failed to fetch {url}: {e}") + return None diff --git a/src/company_utils/snowflake.py b/src/company_utils/snowflake.py new file mode 100644 index 0000000..1b77901 --- /dev/null +++ b/src/company_utils/snowflake.py @@ -0,0 +1,36 @@ +from metaflow import Snowflake + +QUERY_BATCH = """ +SELECT * +FROM free_company_dataset.public.freecompanydataset +WHERE WEBSITE IS NOT NULL AND TRIM(WEBSITE) != '' + AND id NOT IN ({excluded}) +LIMIT 100 +""" + +QUERY_BATCH_NO_EXCLUDE = """ +SELECT * +FROM free_company_dataset.public.freecompanydataset +WHERE WEBSITE IS NOT NULL AND TRIM(WEBSITE) != '' +LIMIT 100 +""" + + +def fetch_company_batch(processed_ids=None): + """ + Fetch a batch of up to 100 company rows from Snowflake, + skipping any IDs in processed_ids. Only includes rows + with a valid website URL (non-empty DOMAIN). + + Returns a pandas DataFrame. + """ + with Snowflake(integration="snowflake-test") as cn: + cursor = cn.cursor() + if processed_ids: + placeholders = ",".join(f"'{pid}'" for pid in processed_ids) + query = QUERY_BATCH.format(excluded=placeholders) + else: + query = QUERY_BATCH_NO_EXCLUDE + cursor.execute(query) + df = cursor.fetch_pandas_all() + return df