Add AI functionality; known UI regressions remain and need further cleanup.
This commit is contained in:
@@ -0,0 +1,138 @@
|
||||
"""Ollama Cloud API client helpers for AI search and chat."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from config import (
|
||||
OLLAMA_API_BASE,
|
||||
OLLAMA_API_KEY,
|
||||
OLLAMA_DEFAULT_MODEL,
|
||||
OLLAMA_REQUEST_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
class OllamaCloudError(RuntimeError):
    """Signals that Ollama Cloud rejected a request or could not be reached.

    Carries an HTTP-style ``status_code`` (default 502 Bad Gateway) so API
    handlers can forward an appropriate status to their own callers.
    """

    def __init__(self, message: str, status_code: int = 502) -> None:
        # The message behaves like any RuntimeError message; only the
        # status code is stored as extra state.
        super().__init__(message)
        self.status_code = status_code
|
||||
|
||||
|
||||
def is_ollama_configured() -> bool:
    """Report whether an Ollama Cloud API key has been configured."""
    # An unset or empty OLLAMA_API_KEY means the cloud features are disabled.
    return OLLAMA_API_KEY is not None and OLLAMA_API_KEY != ""
|
||||
|
||||
|
||||
def default_model() -> str:
    """Return the model name used when the caller does not pick one."""
    return OLLAMA_DEFAULT_MODEL
|
||||
|
||||
|
||||
def _headers(require_auth: bool = False) -> dict[str, str]:
    """Build the JSON request headers, attaching a bearer token when available.

    Raises:
        OllamaCloudError: with status 503 when ``require_auth`` is true but
            no ``OLLAMA_API_KEY`` has been configured.
    """
    if not OLLAMA_API_KEY:
        if require_auth:
            raise OllamaCloudError(
                "Ollama Cloud is not configured. Set OLLAMA_API_KEY before using AI responses.",
                status_code=503,
            )
        # Unauthenticated requests are allowed for endpoints that permit it.
        return {"Content-Type": "application/json"}
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {OLLAMA_API_KEY}",
    }
|
||||
|
||||
|
||||
def _normalise_error(response: httpx.Response) -> str:
    """Extract a human-readable error message from an HTTP error response.

    Prefers the JSON ``error`` or ``detail`` field; falls back to the whole
    JSON payload, then to the raw body text, then to the HTTP reason phrase.
    """
    try:
        payload = response.json()
    except ValueError:
        # Body is not JSON — use the plain text, or the status phrase if empty.
        return response.text.strip() or response.reason_phrase
    if not isinstance(payload, dict):
        # Fix: a JSON list/string/number body has no .get(); previously this
        # raised AttributeError instead of producing an error message.
        return str(payload)
    detail = payload.get("error") or payload.get("detail") or payload
    return str(detail)
|
||||
|
||||
|
||||
async def list_models() -> list[dict[str, Any]]:
    """Fetch the models available on Ollama Cloud, sorted by display name.

    Raises:
        OllamaCloudError: when the API responds with an error status.
    """
    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.get(
            f"{OLLAMA_API_BASE}/api/tags",
            headers=_headers(require_auth=False),
        )
        if response.status_code >= 400:
            raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
        models = response.json().get("models") or []

        def sort_key(entry: dict[str, Any]) -> str:
            # Some entries expose "name", others "model"; fall back to "".
            return entry.get("name") or entry.get("model") or ""

        return sorted(models, key=sort_key)
|
||||
|
||||
|
||||
async def chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> dict[str, Any]:
    """Run a non-streaming chat completion against Ollama Cloud.

    Args:
        model: Model name to run.
        messages: Chat history in the Ollama messages format.
        think: Thinking-mode setting; ``None`` or ``"off"`` omits it entirely.

    Returns:
        The parsed JSON response body.

    Raises:
        OllamaCloudError: when the API responds with an error status, or when
            no API key is configured (via ``_headers``).
    """
    body: dict[str, Any] = {"model": model, "messages": messages, "stream": False}
    # None and "off" both mean "do not send a think field at all".
    if think not in (None, "off"):
        body["think"] = think

    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            f"{OLLAMA_API_BASE}/api/chat",
            headers=_headers(require_auth=True),
            json=body,
        )
        if response.status_code >= 400:
            raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
        return response.json()
|
||||
|
||||
|
||||
async def stream_chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Stream chat completion chunks from Ollama Cloud as parsed dicts.

    Yields one dict per newline-delimited JSON chunk. Lines that are not
    valid JSON are wrapped as ``{"message": {"content": <line>}}`` so the
    consumer always receives the same shape.

    Raises:
        OllamaCloudError: when the API responds with an error status, or when
            no API key is configured (via ``_headers``).
    """
    request_body: dict[str, Any] = {"model": model, "messages": messages, "stream": True}
    # None and "off" both mean "do not send a think field at all".
    if think not in (None, "off"):
        request_body["think"] = think

    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream(
            "POST",
            f"{OLLAMA_API_BASE}/api/chat",
            headers=_headers(require_auth=True),
            json=request_body,
        ) as response:
            if response.status_code >= 400:
                # The streamed body must be read fully before it can be shown.
                raw = await response.aread()
                text = raw.decode("utf-8", errors="replace")
                raise OllamaCloudError(text or response.reason_phrase, status_code=response.status_code)

            async for raw_line in response.aiter_lines():
                stripped = raw_line.strip()
                if not stripped:
                    continue
                try:
                    chunk = json.loads(stripped)
                except json.JSONDecodeError:
                    # Pass malformed lines through in the standard chunk shape.
                    chunk = {"message": {"content": stripped}}
                yield chunk
|
||||
|
||||
|
||||
async def web_search(query: str, max_results: int = 5) -> list[dict[str, Any]]:
    """Run a web search through Ollama Cloud and return the result entries.

    Args:
        query: Search query string.
        max_results: Desired number of results; clamped to the API's 1–10 range.

    Raises:
        OllamaCloudError: when the API responds with an error status, or when
            no API key is configured (via ``_headers``).
    """
    # Clamp to the range the endpoint accepts.
    capped_results = min(10, max(1, max_results))
    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            f"{OLLAMA_API_BASE}/api/web_search",
            headers=_headers(require_auth=True),
            json={"query": query, "max_results": capped_results},
        )
        if response.status_code >= 400:
            raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
        return response.json().get("results") or []
|
||||
Reference in New Issue
Block a user