"""FastAPI entry point for the sFetch backend."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from collections.abc import AsyncIterator
|
|
from datetime import UTC, datetime
|
|
|
|
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from crawler import sFetchBot
|
|
from config import TOP_SITE_SEED_LIMIT, TOP_SITE_SEED_META_KEY
|
|
from database import (
|
|
count_image_results,
|
|
count_search_results,
|
|
count_video_results,
|
|
get_meta_value,
|
|
get_stats,
|
|
init_db,
|
|
set_meta_value,
|
|
)
|
|
from models import AIAnswerResponse, AIChatRequest, AISearchRequest, AISource, CrawlRequest, SearchResponse
|
|
from ollama_cloud import (
|
|
OllamaCloudError,
|
|
chat as ollama_chat,
|
|
default_model,
|
|
is_ollama_configured,
|
|
list_models as list_ollama_models,
|
|
stream_chat as ollama_stream_chat,
|
|
web_search as ollama_web_search,
|
|
)
|
|
from searcher import search, search_images_api, search_videos_api
|
|
from top_sites import load_top_site_seed_urls
|
|
|
|
app = FastAPI(title="sFetch API", version="1.0.0")
|
|
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"],
|
|
allow_credentials=False,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
|
|
def _utc_now() -> str:
|
|
return datetime.now(UTC).isoformat()
|
|
|
|
|
|
def _set_seed_status(**updates: object) -> None:
    """Merge *updates* into the shared top-site seed status dict.

    Always refreshes ``updated_at``; explicit keys in *updates* win over it.
    The status is replaced wholesale on ``app.state`` so readers never see a
    half-updated dict.
    """
    snapshot = dict(getattr(app.state, "_top_scrape_status", {}))
    snapshot["updated_at"] = _utc_now()
    snapshot.update(updates)
    app.state._top_scrape_status = snapshot
|
|
|
|
|
|
async def _scrape_top_sites(force: bool = False) -> None:
    """Seed the search index with the top-site list, at most once.

    Runs at startup and on demand via ``/crawl/top-sites``. The whole routine
    executes under ``app.state._crawl_lock`` so concurrent triggers cannot run
    two seed crawls at once; progress is published through ``_set_seed_status``
    for the status endpoint to read.

    Args:
        force: When True, re-run the seed even if a previous run is recorded
            (in-process flag, meta-table record, or a large-enough index).
    """
    await init_db()

    async with app.state._crawl_lock:
        # Fast path: this process already completed a seed run.
        if app.state._top_scrape_done and not force:
            return

        # A previous process may have recorded the seed source in the meta table.
        existing_seed = await get_meta_value(TOP_SITE_SEED_META_KEY)
        if existing_seed and not force:
            stats = await get_stats()
            _set_seed_status(
                state="stored",
                message="Top-site seed already stored in the database.",
                total=TOP_SITE_SEED_LIMIT,
                indexed=stats["total_pages"],
                source=existing_seed,
            )
            app.state._top_scrape_done = True
            return

        # No meta record, but the index may already be large enough; record it
        # as the seed source so later startups take the meta fast path above.
        stats = await get_stats()
        if stats["total_pages"] >= TOP_SITE_SEED_LIMIT and not force:
            source = "existing database"
            await set_meta_value(TOP_SITE_SEED_META_KEY, source)
            _set_seed_status(
                state="stored",
                message="Top-site seed already stored in the database.",
                total=TOP_SITE_SEED_LIMIT,
                indexed=stats["total_pages"],
                source=source,
            )
            app.state._top_scrape_done = True
            return

        _set_seed_status(state="loading", message="Loading top-site list.", total=TOP_SITE_SEED_LIMIT, indexed=0)
        seed_urls, source = await load_top_site_seed_urls(limit=TOP_SITE_SEED_LIMIT)
        _set_seed_status(
            state="running",
            message=f"Seeding {len(seed_urls)} non-adult top sites.",
            total=len(seed_urls),
            indexed=0,
            source=source,
        )

        print(f"sFetch: seeding index with {len(seed_urls)} non-adult top sites from {source}...")
        # depth 0 + one page per domain: index only each seed site's front page.
        bot = sFetchBot(max_depth=0, same_domain_only=True, max_pages_per_domain=1, max_concurrency=12)
        try:
            await bot.start(seed_urls)
        except Exception as exc:
            # Deliberate best-effort: record the failure but leave
            # _top_scrape_done False so a forced retry remains possible.
            _set_seed_status(state="error", message=f"Top-site seed failed: {exc}", indexed=bot.indexed_count)
            print(f"sFetch: top-site seed failed ({exc})")
            return

        # Persist the seed source so future startups skip the crawl entirely.
        await set_meta_value(TOP_SITE_SEED_META_KEY, source)
        _set_seed_status(
            state="complete",
            message="Top-site seed complete.",
            total=len(seed_urls),
            indexed=bot.indexed_count,
            source=source,
        )
        print(f"sFetch: seeding complete. {bot.indexed_count} pages indexed.")
        app.state._top_scrape_done = True
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event() -> None:
    """Initialise crawl bookkeeping state and kick off the top-site seed.

    The background task handle is stored on ``app.state`` because the event
    loop keeps only a weak reference to tasks: a bare ``create_task`` result
    can be garbage-collected before it finishes, silently cancelling the seed.
    """
    app.state._top_scrape_done = False
    app.state._crawl_lock = asyncio.Lock()
    app.state._top_scrape_status = {
        "state": "idle",
        "message": "Waiting to check top-site seed.",
        "total": TOP_SITE_SEED_LIMIT,
        "indexed": 0,
        "source": None,
        "updated_at": _utc_now(),
    }
    # Hold a strong reference to the task so it cannot be collected mid-run.
    app.state._top_scrape_task = asyncio.create_task(_scrape_top_sites())
|
|
|
|
|
|
@app.get("/")
async def health_check() -> dict[str, str]:
    """Liveness probe for the API root."""
    payload = {"status": "sFetch is alive"}
    return payload
|
|
|
|
|
|
@app.get("/search", response_model=SearchResponse)
async def search_endpoint(
    q: str = Query(..., description="Search query"),
    # The Python name avoids shadowing the builtin `type`; the public query
    # parameter is still `type` thanks to the alias.
    search_type: str = Query("web", alias="type", description="Search type: web, image, or video"),
    limit: int = Query(10, ge=1, le=50),
    offset: int = Query(0, ge=0),
) -> SearchResponse:
    """Search the index for web pages, images, or videos.

    Raises:
        HTTPException: 400 when the query is blank or the type is unknown.
    """
    query = q.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query parameter 'q' cannot be empty.")

    # Validate the type up front so the branches below are exhaustive.
    if search_type not in {"web", "image", "video"}:
        raise HTTPException(status_code=400, detail="Invalid search type. Use web, image, or video.")

    if search_type == "image":
        results = await search_images_api(query=query, limit=limit, offset=offset)
        total = await count_image_results(query)
    elif search_type == "video":
        results = await search_videos_api(query=query, limit=limit, offset=offset)
        total = await count_video_results(query)
    else:
        results = await search(query=query, limit=limit, offset=offset)
        total = await count_search_results(query)
    return SearchResponse(query=query, type=search_type, total=total, results=results)
|
|
|
|
|
|
async def _run_crawl_job(request: CrawlRequest) -> None:
    """Execute a crawl in the background, logging (not raising) on failure."""
    try:
        crawler = sFetchBot(
            max_depth=request.max_depth,
            max_pages_per_domain=request.max_pages_per_domain,
            same_domain_only=request.same_domain_only,
        )
        await crawler.start(request.seed_urls)
    except Exception as exc:
        # Background task: surface the failure in the server log only.
        print(f"sFetch: crawl job failed ({exc})")
|
|
|
|
|
|
@app.post("/crawl")
async def crawl_endpoint(request: CrawlRequest, background_tasks: BackgroundTasks) -> dict[str, object]:
    """Queue a crawl job for the given seeds and return immediately."""
    background_tasks.add_task(_run_crawl_job, request)
    ack: dict[str, object] = {"message": "Crawl started", "seed_urls": request.seed_urls}
    return ack
|
|
|
|
|
|
@app.post("/crawl/top-sites")
async def crawl_top_sites_endpoint(
    background_tasks: BackgroundTasks,
    force: bool = Query(False, description="Run the top-site seed again even if it is marked complete."),
) -> dict[str, object]:
    """Queue the top-site seed crawl, optionally forcing a re-run."""
    background_tasks.add_task(_scrape_top_sites, force)
    ack: dict[str, object] = {"message": "Top-site crawl queued", "force": force}
    return ack
|
|
|
|
|
|
@app.get("/crawl/top-sites/status")
async def crawl_top_sites_status_endpoint() -> dict[str, object]:
    """Report the current top-site seed progress, or an idle default."""
    # Fallback covers the window before the startup hook has populated state.
    idle_status: dict[str, object] = {
        "state": "idle",
        "message": "Top-site seed has not started.",
        "total": TOP_SITE_SEED_LIMIT,
        "indexed": 0,
        "source": None,
        "updated_at": None,
    }
    return getattr(app.state, "_top_scrape_status", idle_status)
|
|
|
|
|
|
@app.get("/stats")
async def stats_endpoint() -> dict[str, object]:
    """Expose index statistics straight from the database layer."""
    return await get_stats()
|
|
|
|
|
|
@app.get("/ai/config")
async def ai_config_endpoint() -> dict[str, object]:
    """Describe the AI provider configuration for the frontend."""
    config: dict[str, object] = {
        "configured": is_ollama_configured(),
        "default_model": default_model(),
        "provider": "Ollama Cloud",
    }
    return config
|
|
|
|
|
|
@app.get("/ai/models")
async def ai_models_endpoint() -> dict[str, object]:
    """List the models available from Ollama Cloud.

    Raises:
        HTTPException: mirrors the upstream OllamaCloudError status.
    """
    try:
        available = await list_ollama_models()
    except OllamaCloudError as exc:
        raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc
    return {"default_model": default_model(), "models": available}
|
|
|
|
|
|
@app.post("/ai/chat", response_model=AIAnswerResponse)
async def ai_chat_endpoint(request: AIChatRequest) -> AIAnswerResponse:
    """Run a non-streaming chat completion, optionally grounded in web search.

    Raises:
        HTTPException: 400 on missing model/messages; otherwise mirrors the
            upstream OllamaCloudError status.
    """
    model = (request.model or default_model()).strip()
    if not model:
        raise HTTPException(status_code=400, detail="Model is required.")
    if not request.messages:
        raise HTTPException(status_code=400, detail="At least one message is required.")

    try:
        messages, sources = await _build_chat_messages_and_sources(request)
        completion = await ollama_chat(model=model, messages=messages, think=request.think)
    except OllamaCloudError as exc:
        raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc

    reply = completion.get("message") or {}
    return AIAnswerResponse(
        model=completion.get("model") or model,
        content=reply.get("content") or "",
        thinking=reply.get("thinking"),
        sources=sources,
        configured=is_ollama_configured(),
    )
|
|
|
|
|
|
def _sse(event: str, data: object) -> str:
|
|
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
|
|
async def _build_chat_messages_and_sources(request: AIChatRequest) -> tuple[list[dict[str, object]], list[AISource]]:
    """Normalise chat messages and, when asked, prepend web-search context.

    Drops messages that carry neither text content nor tool calls, then — if
    ``request.use_web_search`` is set — runs a web search for the most recent
    non-blank user message and injects the results as a leading system
    message so the model can cite them by number.

    Returns:
        The message dicts to send to the model, and the AISource list used
        to build the injected context (empty when web search is off).

    Raises:
        OllamaCloudError: status 400 when no usable message remains
            (callers translate this into an HTTP error).
    """
    # Keep only messages that have content or tool calls.
    messages = [
        message.model_dump(exclude_none=True)
        for message in request.messages
        if message.content.strip() or message.tool_calls
    ]
    if not messages:
        raise OllamaCloudError("At least one message is required.", status_code=400)

    sources: list[AISource] = []
    if request.use_web_search:
        # The latest non-blank user turn drives the web query.
        latest_user_message = next(
            (message.content for message in reversed(request.messages) if message.role == "user" and message.content.strip()),
            "",
        )
        if latest_user_message:
            web_results = await ollama_web_search(latest_user_message, max_results=request.web_result_limit)
            # Results without a URL cannot be cited, so they are skipped.
            sources = [
                AISource(
                    title=result.get("title") or result.get("url") or "Web result",
                    url=result.get("url") or "",
                    source_type="web",
                    content=result.get("content") or "",
                )
                for result in web_results
                if result.get("url")
            ]
        if sources:
            # Inject the numbered source context as the first (system) message.
            context = "\n".join(_source_text(source, index) for index, source in enumerate(sources, start=1))
            messages.insert(
                0,
                {
                    "role": "system",
                    "content": (
                        "Use the following web search context when it is relevant. "
                        "Cite sources inline using bracket numbers such as [1].\n\n"
                        f"{context}"
                    ),
                },
            )

    return messages, sources
|
|
|
|
|
|
async def _stream_ollama_events(
    model: str,
    messages: list[dict[str, object]],
    think: bool | str | None,
    sources: list[AISource],
) -> AsyncIterator[str]:
    """Yield SSE frames for a streamed Ollama chat completion.

    Emits, in order: one ``meta`` event (model, configured flag, sources),
    then interleaved ``thinking``/``content`` delta events, and finally a
    ``done`` event carrying the accumulated text. Failures are reported as
    an ``error`` event rather than raised, because HTTP headers are already
    sent once streaming has begun.
    """
    content = ""
    thinking = ""
    # Announce stream metadata before any model output arrives.
    yield _sse(
        "meta",
        {
            "model": model,
            "configured": is_ollama_configured(),
            "sources": [source.model_dump() for source in sources],
        },
    )

    try:
        async for chunk in ollama_stream_chat(model=model, messages=messages, think=think):
            message = chunk.get("message") or {}
            thinking_delta = message.get("thinking") or ""
            content_delta = message.get("content") or ""

            if thinking_delta:
                thinking += thinking_delta
                yield _sse("thinking", {"delta": thinking_delta})

            if content_delta:
                content += content_delta
                yield _sse("content", {"delta": content_delta})

            # Final chunk: emit the full accumulated transcript and stop.
            if chunk.get("done"):
                yield _sse(
                    "done",
                    {
                        "model": chunk.get("model") or model,
                        "content": content,
                        "thinking": thinking,
                        "sources": [source.model_dump() for source in sources],
                    },
                )
                return
    except OllamaCloudError as exc:
        yield _sse("error", {"detail": str(exc), "status_code": exc.status_code})
    except Exception as exc:
        # Catch-all so the client always receives a terminal error event.
        yield _sse("error", {"detail": f"Streaming failed: {exc}", "status_code": 502})
|
|
|
|
|
|
@app.post("/ai/chat/stream")
async def ai_chat_stream_endpoint(request: AIChatRequest) -> StreamingResponse:
    """Stream an AI chat completion as server-sent events.

    Raises:
        HTTPException: 400 on bad input, 503 when Ollama is unconfigured,
            otherwise mirrors the upstream OllamaCloudError status.
    """
    model = (request.model or default_model()).strip()
    if not model:
        raise HTTPException(status_code=400, detail="Model is required.")
    if not is_ollama_configured():
        raise HTTPException(status_code=503, detail="Ollama Cloud is not configured. Set OLLAMA_API_KEY.")

    try:
        messages, sources = await _build_chat_messages_and_sources(request)
    except OllamaCloudError as exc:
        raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc

    event_stream = _stream_ollama_events(model=model, messages=messages, think=request.think, sources=sources)
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(event_stream, media_type="text/event-stream", headers=sse_headers)
|
|
|
|
|
|
def _source_text(source: AISource, index: int) -> str:
|
|
return (
|
|
f"[{index}] {source.title}\n"
|
|
f"Type: {source.source_type}\n"
|
|
f"URL: {source.url}\n"
|
|
f"Excerpt: {source.content[:1200]}\n"
|
|
)
|
|
|
|
|
|
async def _build_ai_search_sources(request: AISearchRequest) -> list[AISource]:
    """Collect local-index hits (and, optionally, web hits) as AISource records."""
    sources: list[AISource] = []

    local_hits = await search(query=request.query, limit=request.local_result_limit, offset=0)
    for hit in local_hits:
        sources.append(
            AISource(
                title=hit["title"],
                url=hit["url"],
                source_type="local",
                content=hit["snippet"],
            )
        )

    if request.include_web:
        web_hits = await ollama_web_search(request.query, max_results=request.web_result_limit)
        for hit in web_hits:
            url = hit.get("url")
            if not url:
                # A source without a URL cannot be cited; skip it.
                continue
            sources.append(
                AISource(
                    title=hit.get("title") or url or "Web result",
                    url=url or "",
                    source_type="web",
                    content=hit.get("content") or "",
                )
            )

    return sources
|
|
|
|
|
|
@app.post("/ai/search", response_model=AIAnswerResponse)
async def ai_search_endpoint(request: AISearchRequest) -> AIAnswerResponse:
    """Answer a search query with the LLM, grounded in local/web sources.

    Non-streaming twin of ``/ai/search/stream``. Both endpoints build their
    prompt via ``_build_ai_search_messages`` so the system prompt and source
    context cannot drift apart between them (previously this endpoint carried
    a duplicated, slightly different copy of that prompt-building code).

    Raises:
        HTTPException: 400 on missing model/query; otherwise mirrors the
            upstream OllamaCloudError status.
    """
    model = (request.model or default_model()).strip()
    query = request.query.strip()
    if not model:
        raise HTTPException(status_code=400, detail="Model is required.")
    if not query:
        raise HTTPException(status_code=400, detail="Query is required.")

    try:
        messages, sources = await _build_ai_search_messages(request)
        response = await ollama_chat(model=model, messages=messages, think=request.think)
    except OllamaCloudError as exc:
        raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc

    message = response.get("message") or {}
    return AIAnswerResponse(
        model=response.get("model") or model,
        content=message.get("content") or "",
        thinking=message.get("thinking"),
        sources=sources,
        configured=is_ollama_configured(),
    )
|
|
|
|
|
|
async def _build_ai_search_messages(request: AISearchRequest) -> tuple[list[dict[str, str]], list[AISource]]:
    """Build the system/user prompt pair for AI search from gathered sources."""
    sources = await _build_ai_search_sources(request)
    numbered_entries = [_source_text(source, index) for index, source in enumerate(sources, start=1)]
    source_context = "\n".join(numbered_entries)
    if not source_context:
        # Tell the model explicitly that nothing was found, instead of
        # handing it an empty context block.
        source_context = "No search sources were found for this query."

    system_prompt = (
        "You are sFetch AI, a precise search assistant. Answer only from the provided sources. "
        "Write in a neutral, useful tone with direct synthesis. "
        "Cite sources inline using bracket numbers such as [1]. "
        "If sources are insufficient, say what is missing rather than guessing."
    )
    user_prompt = f"Search query: {request.query.strip()}\n\nSources:\n{source_context}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return messages, sources
|
|
|
|
|
|
@app.post("/ai/search/stream")
async def ai_search_stream_endpoint(request: AISearchRequest) -> StreamingResponse:
    """Stream an AI search answer as server-sent events.

    Raises:
        HTTPException: 400 on bad input, 503 when Ollama is unconfigured,
            otherwise mirrors the upstream OllamaCloudError status.
    """
    model = (request.model or default_model()).strip()
    query = request.query.strip()
    if not model:
        raise HTTPException(status_code=400, detail="Model is required.")
    if not query:
        raise HTTPException(status_code=400, detail="Query is required.")
    if not is_ollama_configured():
        raise HTTPException(status_code=503, detail="Ollama Cloud is not configured. Set OLLAMA_API_KEY.")

    try:
        messages, sources = await _build_ai_search_messages(request)
    except OllamaCloudError as exc:
        raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc

    event_stream = _stream_ollama_events(model=model, messages=messages, think=request.think, sources=sources)
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(event_stream, media_type="text/event-stream", headers=sse_headers)
|