Add AI functionality; note: known UI regressions remain and still need to be fixed.
This commit is contained in:
@@ -1,9 +1,18 @@
|
||||
"""Application-wide configuration for sFetch."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(BASE_DIR.parent / ".env")
|
||||
load_dotenv(BASE_DIR / ".env")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
MAX_CRAWL_DEPTH = 2
|
||||
MAX_PAGES_PER_DOMAIN = 50
|
||||
CRAWL_DELAY_SECONDS = 1.0
|
||||
@@ -14,6 +23,10 @@ TOP_SITE_SOURCE_URL = "https://tranco-list.eu/top-1m.csv.zip"
|
||||
# Top-site seeding.
TOP_SITE_SEED_LIMIT = 1000
TOP_SITE_DOWNLOAD_TIMEOUT_SECONDS = 30.0
TOP_SITE_SEED_META_KEY = "top_site_seed_v1"

# Ollama Cloud (AI answers). An empty OLLAMA_API_KEY disables AI features;
# the base URL is normalised without a trailing slash so paths can be appended.
OLLAMA_API_BASE = os.getenv("OLLAMA_API_BASE", "https://ollama.com").rstrip("/")
OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY", "")
OLLAMA_DEFAULT_MODEL = os.getenv("OLLAMA_DEFAULT_MODEL", "gpt-oss:120b")
OLLAMA_REQUEST_TIMEOUT_SECONDS = 90.0
||||
ADULT_DOMAINS = {
|
||||
"pornhub.com", "xvideos.com", "xnxx.com", "xhamster.com", "redtube.com",
|
||||
|
||||
+302
-1
@@ -3,10 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from crawler import sFetchBot
|
||||
from config import TOP_SITE_SEED_LIMIT, TOP_SITE_SEED_META_KEY
|
||||
@@ -19,7 +22,16 @@ from database import (
|
||||
init_db,
|
||||
set_meta_value,
|
||||
)
|
||||
from models import CrawlRequest, SearchResponse
|
||||
from models import AIAnswerResponse, AIChatRequest, AISearchRequest, AISource, CrawlRequest, SearchResponse
|
||||
from ollama_cloud import (
|
||||
OllamaCloudError,
|
||||
chat as ollama_chat,
|
||||
default_model,
|
||||
is_ollama_configured,
|
||||
list_models as list_ollama_models,
|
||||
stream_chat as ollama_stream_chat,
|
||||
web_search as ollama_web_search,
|
||||
)
|
||||
from searcher import search, search_images_api, search_videos_api
|
||||
from top_sites import load_top_site_seed_urls
|
||||
|
||||
@@ -205,3 +217,292 @@ async def crawl_top_sites_status_endpoint() -> dict[str, object]:
|
||||
async def stats_endpoint() -> dict[str, object]:
    """Expose aggregate index statistics from the database layer."""
    return await get_stats()
|
||||
|
||||
|
||||
@app.get("/ai/config")
|
||||
async def ai_config_endpoint() -> dict[str, object]:
|
||||
return {
|
||||
"configured": is_ollama_configured(),
|
||||
"default_model": default_model(),
|
||||
"provider": "Ollama Cloud",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/ai/models")
|
||||
async def ai_models_endpoint() -> dict[str, object]:
|
||||
try:
|
||||
models = await list_ollama_models()
|
||||
except OllamaCloudError as exc:
|
||||
raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc
|
||||
return {
|
||||
"default_model": default_model(),
|
||||
"models": models,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/ai/chat", response_model=AIAnswerResponse)
|
||||
async def ai_chat_endpoint(request: AIChatRequest) -> AIAnswerResponse:
|
||||
model = (request.model or default_model()).strip()
|
||||
if not model:
|
||||
raise HTTPException(status_code=400, detail="Model is required.")
|
||||
if not request.messages:
|
||||
raise HTTPException(status_code=400, detail="At least one message is required.")
|
||||
|
||||
try:
|
||||
messages, sources = await _build_chat_messages_and_sources(request)
|
||||
response = await ollama_chat(model=model, messages=messages, think=request.think)
|
||||
except OllamaCloudError as exc:
|
||||
raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc
|
||||
|
||||
message = response.get("message") or {}
|
||||
return AIAnswerResponse(
|
||||
model=response.get("model") or model,
|
||||
content=message.get("content") or "",
|
||||
thinking=message.get("thinking"),
|
||||
sources=sources,
|
||||
configured=is_ollama_configured(),
|
||||
)
|
||||
|
||||
|
||||
def _sse(event: str, data: object) -> str:
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
async def _build_chat_messages_and_sources(request: AIChatRequest) -> tuple[list[dict[str, object]], list[AISource]]:
    """Prepare the outbound chat payload and its citable sources.

    Drops messages that carry neither text nor tool calls, optionally runs an
    Ollama web search on the latest non-empty user message, and prepends a
    system prompt containing the numbered search context so the model can cite
    sources as [1], [2], ...

    Raises:
        OllamaCloudError: with status 400 if no usable message remains, or
            whatever the web-search call raises on failure.
    """
    # exclude_none keeps the payload minimal for the Ollama API.
    messages = [
        message.model_dump(exclude_none=True)
        for message in request.messages
        if message.content.strip() or message.tool_calls
    ]
    if not messages:
        raise OllamaCloudError("At least one message is required.", status_code=400)

    sources: list[AISource] = []
    if request.use_web_search:
        # Only the most recent non-empty user turn drives the search query.
        latest_user_message = next(
            (message.content for message in reversed(request.messages) if message.role == "user" and message.content.strip()),
            "",
        )
        if latest_user_message:
            web_results = await ollama_web_search(latest_user_message, max_results=request.web_result_limit)
            # Results without a URL cannot be cited, so they are skipped.
            sources = [
                AISource(
                    title=result.get("title") or result.get("url") or "Web result",
                    url=result.get("url") or "",
                    source_type="web",
                    content=result.get("content") or "",
                )
                for result in web_results
                if result.get("url")
            ]
            if sources:
                # Numbered context entries match the bracket-citation format
                # requested in the system prompt below.
                context = "\n".join(_source_text(source, index) for index, source in enumerate(sources, start=1))
                messages.insert(
                    0,
                    {
                        "role": "system",
                        "content": (
                            "Use the following web search context when it is relevant. "
                            "Cite sources inline using bracket numbers such as [1].\n\n"
                            f"{context}"
                        ),
                    },
                )

    return messages, sources
|
||||
|
||||
|
||||
async def _stream_ollama_events(
    model: str,
    messages: list[dict[str, object]],
    think: bool | str | None,
    sources: list[AISource],
) -> AsyncIterator[str]:
    """Yield SSE frames for a streamed Ollama chat completion.

    Frame order: one "meta" frame (model, configured flag, sources), then
    interleaved "thinking"/"content" delta frames, then a final "done" frame
    carrying the fully accumulated text. Failures become an "error" frame
    rather than an exception so the HTTP stream terminates cleanly.
    """
    content = ""
    thinking = ""
    # Emit metadata first so clients can render sources before any tokens arrive.
    yield _sse(
        "meta",
        {
            "model": model,
            "configured": is_ollama_configured(),
            "sources": [source.model_dump() for source in sources],
        },
    )

    try:
        async for chunk in ollama_stream_chat(model=model, messages=messages, think=think):
            message = chunk.get("message") or {}
            thinking_delta = message.get("thinking") or ""
            content_delta = message.get("content") or ""

            if thinking_delta:
                thinking += thinking_delta
                yield _sse("thinking", {"delta": thinking_delta})

            if content_delta:
                content += content_delta
                yield _sse("content", {"delta": content_delta})

            if chunk.get("done"):
                # The final frame repeats the full accumulated answer so a
                # client that missed deltas can still render the result.
                yield _sse(
                    "done",
                    {
                        "model": chunk.get("model") or model,
                        "content": content,
                        "thinking": thinking,
                        "sources": [source.model_dump() for source in sources],
                    },
                )
                return
    except OllamaCloudError as exc:
        yield _sse("error", {"detail": str(exc), "status_code": exc.status_code})
    except Exception as exc:
        # Broad catch is deliberate: any failure mid-stream must surface as an
        # "error" SSE frame instead of an unhandled exception.
        yield _sse("error", {"detail": f"Streaming failed: {exc}", "status_code": 502})
|
||||
|
||||
|
||||
@app.post("/ai/chat/stream")
|
||||
async def ai_chat_stream_endpoint(request: AIChatRequest) -> StreamingResponse:
|
||||
model = (request.model or default_model()).strip()
|
||||
if not model:
|
||||
raise HTTPException(status_code=400, detail="Model is required.")
|
||||
if not is_ollama_configured():
|
||||
raise HTTPException(status_code=503, detail="Ollama Cloud is not configured. Set OLLAMA_API_KEY.")
|
||||
|
||||
try:
|
||||
messages, sources = await _build_chat_messages_and_sources(request)
|
||||
except OllamaCloudError as exc:
|
||||
raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc
|
||||
|
||||
return StreamingResponse(
|
||||
_stream_ollama_events(model=model, messages=messages, think=request.think, sources=sources),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
def _source_text(source: AISource, index: int) -> str:
|
||||
return (
|
||||
f"[{index}] {source.title}\n"
|
||||
f"Type: {source.source_type}\n"
|
||||
f"URL: {source.url}\n"
|
||||
f"Excerpt: {source.content[:1200]}\n"
|
||||
)
|
||||
|
||||
|
||||
async def _build_ai_search_sources(request: AISearchRequest) -> list[AISource]:
    """Gather local-index hits and, optionally, Ollama web hits for a query.

    Local results come first, preserving search ranking; web results without a
    URL are dropped because they cannot be cited.
    """
    local_results = await search(query=request.query, limit=request.local_result_limit, offset=0)
    sources: list[AISource] = []
    for result in local_results:
        sources.append(
            AISource(
                title=result["title"],
                url=result["url"],
                source_type="local",
                content=result["snippet"],
            )
        )

    if request.include_web:
        web_results = await ollama_web_search(request.query, max_results=request.web_result_limit)
        for result in web_results:
            if not result.get("url"):
                continue  # unciteable result
            sources.append(
                AISource(
                    title=result.get("title") or result.get("url") or "Web result",
                    url=result.get("url") or "",
                    source_type="web",
                    content=result.get("content") or "",
                )
            )

    return sources
|
||||
|
||||
|
||||
@app.post("/ai/search", response_model=AIAnswerResponse)
|
||||
async def ai_search_endpoint(request: AISearchRequest) -> AIAnswerResponse:
|
||||
model = (request.model or default_model()).strip()
|
||||
query = request.query.strip()
|
||||
if not model:
|
||||
raise HTTPException(status_code=400, detail="Model is required.")
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="Query is required.")
|
||||
|
||||
try:
|
||||
sources = await _build_ai_search_sources(request)
|
||||
source_context = "\n".join(_source_text(source, index) for index, source in enumerate(sources, start=1))
|
||||
if not source_context:
|
||||
source_context = "No search sources were found for this query."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are sFetch AI, a precise search assistant. Answer only from the provided sources. "
|
||||
"Write in a neutral, professional tone. Keep the response concise. "
|
||||
"Cite sources inline using bracket numbers such as [1]. "
|
||||
"If the sources are insufficient, say what is missing rather than guessing."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Search query: {query}\n\nSources:\n{source_context}",
|
||||
},
|
||||
]
|
||||
response = await ollama_chat(model=model, messages=messages, think=request.think)
|
||||
except OllamaCloudError as exc:
|
||||
raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc
|
||||
|
||||
message = response.get("message") or {}
|
||||
return AIAnswerResponse(
|
||||
model=response.get("model") or model,
|
||||
content=message.get("content") or "",
|
||||
thinking=message.get("thinking"),
|
||||
sources=sources,
|
||||
configured=is_ollama_configured(),
|
||||
)
|
||||
|
||||
|
||||
async def _build_ai_search_messages(request: AISearchRequest) -> tuple[list[dict[str, str]], list[AISource]]:
    """Build the grounded system+user prompt pair for an AI search answer."""
    sources = await _build_ai_search_sources(request)
    context_entries = [_source_text(source, index) for index, source in enumerate(sources, start=1)]
    source_context = "\n".join(context_entries)
    if not source_context:
        # Tell the model explicitly when nothing was found, so it can say so.
        source_context = "No search sources were found for this query."

    system_prompt = (
        "You are sFetch AI, a precise search assistant. Answer only from the provided sources. "
        "Write in a neutral, useful tone with direct synthesis. "
        "Cite sources inline using bracket numbers such as [1]. "
        "If sources are insufficient, say what is missing rather than guessing."
    )
    user_prompt = f"Search query: {request.query.strip()}\n\nSources:\n{source_context}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return messages, sources
|
||||
|
||||
|
||||
@app.post("/ai/search/stream")
|
||||
async def ai_search_stream_endpoint(request: AISearchRequest) -> StreamingResponse:
|
||||
model = (request.model or default_model()).strip()
|
||||
query = request.query.strip()
|
||||
if not model:
|
||||
raise HTTPException(status_code=400, detail="Model is required.")
|
||||
if not query:
|
||||
raise HTTPException(status_code=400, detail="Query is required.")
|
||||
if not is_ollama_configured():
|
||||
raise HTTPException(status_code=503, detail="Ollama Cloud is not configured. Set OLLAMA_API_KEY.")
|
||||
|
||||
try:
|
||||
messages, sources = await _build_ai_search_messages(request)
|
||||
except OllamaCloudError as exc:
|
||||
raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc
|
||||
|
||||
return StreamingResponse(
|
||||
_stream_ollama_events(model=model, messages=messages, think=request.think, sources=sources),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -41,3 +43,43 @@ class CrawlRequest(BaseModel):
|
||||
max_depth: int = Field(default=2, ge=0, le=5)
|
||||
max_pages_per_domain: int = Field(default=50, ge=1, le=500)
|
||||
same_domain_only: bool = True
|
||||
|
||||
|
||||
class AIMessage(BaseModel):
    """One turn in an AI chat conversation."""

    role: Literal["system", "user", "assistant", "tool"]
    # Message text; may be empty for tool-call-only assistant turns.
    content: str = ""
    # Model "thinking" trace, when the provider returns one.
    thinking: str | None = None
    # Name of the tool a "tool" message responds to, if any.
    tool_name: str | None = None
    # Raw tool-call payloads emitted by the assistant, if any.
    tool_calls: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class AIChatRequest(BaseModel):
    """Request body for the /ai/chat and /ai/chat/stream endpoints."""

    # None falls back to the server's default model.
    model: str | None = None
    messages: list[AIMessage] = Field(min_length=1)
    # Provider "think" switch: bool, a provider-specific string, or None.
    think: bool | str | None = None
    # When True, the latest user message also drives an Ollama web search.
    use_web_search: bool = False
    web_result_limit: int = Field(default=5, ge=1, le=10)
|
||||
|
||||
|
||||
class AISearchRequest(BaseModel):
    """Request body for the /ai/search and /ai/search/stream endpoints."""

    query: str = Field(min_length=1)
    # None falls back to the server's default model.
    model: str | None = None
    # Include Ollama web-search results alongside local index hits.
    include_web: bool = True
    local_result_limit: int = Field(default=5, ge=1, le=10)
    web_result_limit: int = Field(default=5, ge=1, le=10)
    # Provider "think" switch: bool, a provider-specific string, or None.
    think: bool | str | None = None
|
||||
|
||||
|
||||
class AISource(BaseModel):
    """A citable search source shown alongside an AI answer."""

    title: str
    url: str
    # Origin of the source: the local crawl index or Ollama web search.
    source_type: Literal["local", "web"]
    # Snippet/excerpt used as model context; may be empty.
    content: str = ""
|
||||
|
||||
|
||||
class AIAnswerResponse(BaseModel):
    """Response body for the non-streaming AI endpoints."""

    model: str
    content: str
    # Model "thinking" trace when requested and returned by the provider.
    thinking: str | None = None
    # Sources the answer was grounded in, in citation order.
    sources: list[AISource] = Field(default_factory=list)
    # Whether an Ollama API key is configured server-side.
    configured: bool = True
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Ollama Cloud API client helpers for AI search and chat."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from config import (
|
||||
OLLAMA_API_BASE,
|
||||
OLLAMA_API_KEY,
|
||||
OLLAMA_DEFAULT_MODEL,
|
||||
OLLAMA_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
class OllamaCloudError(RuntimeError):
    """Raised when Ollama Cloud is unavailable or rejects a request.

    Carries an HTTP-style ``status_code`` (default 502 Bad Gateway) so API
    endpoints can forward the failure verbatim as an ``HTTPException``.
    """

    def __init__(self, message: str, status_code: int = 502) -> None:
        super().__init__(message)
        # Forwarded as the HTTP status of the resulting error response.
        self.status_code = status_code
|
||||
|
||||
|
||||
def is_ollama_configured() -> bool:
    """Return True when an Ollama Cloud API key is present in the environment."""
    return bool(OLLAMA_API_KEY)
|
||||
|
||||
|
||||
def default_model() -> str:
    """Return the model used when a request does not name one explicitly."""
    return OLLAMA_DEFAULT_MODEL
|
||||
|
||||
|
||||
def _headers(require_auth: bool = False) -> dict[str, str]:
    """Build request headers, adding the bearer token when a key is configured.

    Raises:
        OllamaCloudError: with status 503 when *require_auth* is True but no
            API key is set.
    """
    if not OLLAMA_API_KEY:
        if require_auth:
            raise OllamaCloudError(
                "Ollama Cloud is not configured. Set OLLAMA_API_KEY before using AI responses.",
                status_code=503,
            )
        return {"Content-Type": "application/json"}
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {OLLAMA_API_KEY}",
    }
|
||||
|
||||
|
||||
def _normalise_error(response: httpx.Response) -> str:
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError:
|
||||
return response.text.strip() or response.reason_phrase
|
||||
detail = payload.get("error") or payload.get("detail") or payload
|
||||
return str(detail)
|
||||
|
||||
|
||||
async def list_models() -> list[dict[str, Any]]:
    """Fetch the model catalogue from Ollama Cloud, sorted by name.

    Raises:
        OllamaCloudError: when the server answers with an HTTP error status.
    """
    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.get(f"{OLLAMA_API_BASE}/api/tags", headers=_headers(require_auth=False))
    if response.status_code >= 400:
        raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
    models = response.json().get("models") or []

    def _name_key(item: dict[str, Any]) -> str:
        # Some entries expose "name", others "model"; fall back to "".
        return item.get("name") or item.get("model") or ""

    return sorted(models, key=_name_key)
|
||||
|
||||
|
||||
async def chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> dict[str, Any]:
    """Run one non-streaming chat completion against Ollama Cloud.

    Raises:
        OllamaCloudError: if authentication is missing (via ``_headers``) or
            the server answers with an HTTP error status.
    """
    payload: dict[str, Any] = {"model": model, "messages": messages, "stream": False}
    # None and "off" both mean: omit the key and let the provider decide.
    if think not in (None, "off"):
        payload["think"] = think

    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            f"{OLLAMA_API_BASE}/api/chat",
            headers=_headers(require_auth=True),
            json=payload,
        )
    if response.status_code >= 400:
        raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
    return response.json()
|
||||
|
||||
|
||||
async def stream_chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Stream a chat completion, yielding one decoded dict per NDJSON line.

    Raises:
        OllamaCloudError: if authentication is missing (via ``_headers``) or
            the server answers with an HTTP error status; the error body is
            read eagerly so it can be included in the message.
    """
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
    }
    # None and "off" both mean: omit the key and let the provider decide.
    if think is not None and think != "off":
        payload["think"] = think

    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream(
            "POST",
            f"{OLLAMA_API_BASE}/api/chat",
            headers=_headers(require_auth=True),
            json=payload,
        ) as response:
            if response.status_code >= 400:
                # In streaming mode the body must be read explicitly before use.
                body = (await response.aread()).decode("utf-8", errors="replace")
                raise OllamaCloudError(body or response.reason_phrase, status_code=response.status_code)

            async for line in response.aiter_lines():
                clean_line = line.strip()
                if not clean_line:
                    continue
                try:
                    yield json.loads(clean_line)
                except json.JSONDecodeError:
                    # Pass malformed lines through as plain content instead of
                    # aborting the whole stream.
                    yield {"message": {"content": clean_line}}
|
||||
|
||||
|
||||
async def web_search(query: str, max_results: int = 5) -> list[dict[str, Any]]:
    """Run an Ollama Cloud web search and return its raw result dictionaries.

    Raises:
        OllamaCloudError: if authentication is missing (via ``_headers``) or
            the server answers with an HTTP error status.
    """
    # The API accepts 1-10 results; clamp whatever the caller asked for.
    clamped = min(10, max(1, max_results))
    request_body = {"query": query, "max_results": clamped}
    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            f"{OLLAMA_API_BASE}/api/web_search",
            headers=_headers(require_auth=True),
            json=request_body,
        )
    if response.status_code >= 400:
        raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
    return response.json().get("results") or []
|
||||
Reference in New Issue
Block a user