"""Ollama Cloud API client helpers for AI search and chat."""
from __future__ import annotations
import json
from collections.abc import AsyncIterator
from typing import Any
import httpx
from config import (
OLLAMA_API_BASE,
OLLAMA_API_KEY,
OLLAMA_DEFAULT_MODEL,
OLLAMA_REQUEST_TIMEOUT_SECONDS,
)
class OllamaCloudError(RuntimeError):
    """Raised when Ollama Cloud is unavailable or rejects a request."""

    def __init__(self, message: str, status_code: int = 502) -> None:
        # Keep the HTTP status on the exception so API handlers can
        # forward an appropriate status code to their own clients.
        self.status_code = status_code
        super().__init__(message)
def is_ollama_configured() -> bool:
    """Return ``True`` when an Ollama Cloud API key has been configured."""
    return len(OLLAMA_API_KEY or "") > 0
def default_model() -> str:
    """Name of the model used when the caller does not choose one."""
    return OLLAMA_DEFAULT_MODEL
def _headers(require_auth: bool = False) -> dict[str, str]:
    """Build request headers, attaching the bearer token when configured.

    Raises:
        OllamaCloudError: with status 503 when *require_auth* is True but
            no ``OLLAMA_API_KEY`` is set.
    """
    if not OLLAMA_API_KEY:
        if require_auth:
            raise OllamaCloudError(
                "Ollama Cloud is not configured. Set OLLAMA_API_KEY before using AI responses.",
                status_code=503,
            )
        # Unauthenticated request: content type only.
        return {"Content-Type": "application/json"}
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {OLLAMA_API_KEY}",
    }
def _normalise_error(response: httpx.Response) -> str:
try:
payload = response.json()
except ValueError:
return response.text.strip() or response.reason_phrase
detail = payload.get("error") or payload.get("detail") or payload
return str(detail)
async def list_models() -> list[dict[str, Any]]:
    """Fetch the models available on Ollama Cloud, sorted by name.

    Raises:
        OllamaCloudError: when the server responds with an error status.
    """
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    ) as client:
        response = await client.get(
            f"{OLLAMA_API_BASE}/api/tags",
            headers=_headers(require_auth=False),
        )
        if response.status_code >= 400:
            raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
        body = response.json()

    def _sort_key(entry: dict[str, Any]) -> str:
        # Some entries expose "name", others "model"; sort on whichever exists.
        return entry.get("name") or entry.get("model") or ""

    return sorted(body.get("models") or [], key=_sort_key)
async def chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> dict[str, Any]:
    """Send a non-streaming chat request and return the decoded response.

    Args:
        model: Name of the Ollama model to use.
        messages: Chat history as ``{"role": ..., "content": ...}`` dicts.
        think: Optional "thinking" setting; ``None`` or ``"off"`` omits it
            from the request payload.

    Raises:
        OllamaCloudError: when no API key is configured, the request fails
            at the network level, or the server returns an error status.
    """
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": False,
    }
    if think is not None and think != "off":
        payload["think"] = think
    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                f"{OLLAMA_API_BASE}/api/chat",
                headers=_headers(require_auth=True),
                json=payload,
            )
    except httpx.HTTPError as exc:
        # Surface connection failures/timeouts as our own exception type so
        # callers only have to handle OllamaCloudError, matching the class
        # docstring ("unavailable or rejects a request") and every other
        # error path in this module. OllamaCloudError from _headers() is not
        # an httpx.HTTPError and propagates untouched.
        raise OllamaCloudError(f"Ollama Cloud request failed: {exc}", status_code=502) from exc
    if response.status_code >= 400:
        raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
    return response.json()
async def stream_chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Yield decoded chat-response chunks from Ollama Cloud as they arrive.

    Each yielded item is the JSON object for one NDJSON line of the stream.
    A line that fails to decode is wrapped as ``{"message": {"content": ...}}``
    so callers always receive a dict.

    Raises:
        OllamaCloudError: when no API key is configured or the server
            responds with an error status.
    """
    body: dict[str, Any] = {"model": model, "messages": messages, "stream": True}
    if think is not None and think != "off":
        body["think"] = think
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    ) as client:
        async with client.stream(
            "POST",
            f"{OLLAMA_API_BASE}/api/chat",
            headers=_headers(require_auth=True),
            json=body,
        ) as response:
            if response.status_code >= 400:
                # Streaming responses must be read explicitly before the
                # body is available for an error message.
                raw = await response.aread()
                message = raw.decode("utf-8", errors="replace")
                raise OllamaCloudError(message or response.reason_phrase, status_code=response.status_code)
            async for raw_line in response.aiter_lines():
                stripped = raw_line.strip()
                if not stripped:
                    continue
                try:
                    chunk = json.loads(stripped)
                except json.JSONDecodeError:
                    chunk = {"message": {"content": stripped}}
                yield chunk
async def web_search(query: str, max_results: int = 5) -> list[dict[str, Any]]:
    """Run an Ollama Cloud web search and return its result entries.

    Args:
        query: Search terms to submit.
        max_results: Desired result count; clamped to the 1..10 range
            before being sent.

    Raises:
        OllamaCloudError: when no API key is configured or the server
            responds with an error status.
    """
    clamped = min(max(max_results, 1), 10)
    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            f"{OLLAMA_API_BASE}/api/web_search",
            headers=_headers(require_auth=True),
            json={"query": query, "max_results": clamped},
        )
        if response.status_code >= 400:
            raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
        payload = response.json()
    return payload.get("results") or []