"""Ollama Cloud API client helpers for AI search and chat."""

from __future__ import annotations

import json
from collections.abc import AsyncIterator
from typing import Any

import httpx

from config import (
    OLLAMA_API_BASE,
    OLLAMA_API_KEY,
    OLLAMA_DEFAULT_MODEL,
    OLLAMA_REQUEST_TIMEOUT_SECONDS,
)


class OllamaCloudError(RuntimeError):
    """Raised when Ollama Cloud is unavailable or rejects a request."""

    def __init__(self, message: str, status_code: int = 502) -> None:
        super().__init__(message)
        # HTTP status associated with the failure: the upstream status for
        # rejected requests, 503 when no API key is configured, otherwise the
        # 502 default.
        self.status_code = status_code


def is_ollama_configured() -> bool:
    """Return True when a non-empty ``OLLAMA_API_KEY`` is configured."""
    return bool(OLLAMA_API_KEY)


def default_model() -> str:
    """Return the configured ``OLLAMA_DEFAULT_MODEL``."""
    return OLLAMA_DEFAULT_MODEL


def _headers(require_auth: bool = False) -> dict[str, str]:
    """Build JSON request headers, adding a bearer token when a key is set.

    Args:
        require_auth: When true, a missing API key is an error rather than
            sending the request unauthenticated.

    Raises:
        OllamaCloudError: (status 503) if ``require_auth`` is true and
            ``OLLAMA_API_KEY`` is empty.
    """
    headers = {"Content-Type": "application/json"}
    if OLLAMA_API_KEY:
        headers["Authorization"] = f"Bearer {OLLAMA_API_KEY}"
    elif require_auth:
        raise OllamaCloudError(
            "Ollama Cloud is not configured. Set OLLAMA_API_KEY before using AI responses.",
            status_code=503,
        )
    return headers


def _url(path: str) -> str:
    """Join ``OLLAMA_API_BASE`` and an absolute API ``path`` without ``//``."""
    # A base configured with a trailing slash would otherwise yield "//api/...".
    return f"{str(OLLAMA_API_BASE).rstrip('/')}{path}"


def _timeout() -> httpx.Timeout:
    """Return the request timeout (one value for every httpx timeout phase)."""
    return httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)


def _normalise_error(response: httpx.Response) -> str:
    """Extract a readable error message from a failed response.

    For a JSON object body, the ``error`` field is preferred, then ``detail``,
    then the whole object. Any other JSON value is stringified as-is. A
    non-JSON body falls back to its stripped text, or the HTTP reason phrase
    when the body is blank. The body must already be loaded; for streamed
    responses, call ``aread()`` first.
    """
    try:
        payload = response.json()
    except ValueError:
        return response.text.strip() or response.reason_phrase
    if isinstance(payload, dict):
        return str(payload.get("error") or payload.get("detail") or payload)
    # JSON lists/strings/numbers have no ``.get``; report them verbatim rather
    # than crashing with AttributeError while handling the upstream error.
    return str(payload)


def _raise_for_status(response: httpx.Response) -> None:
    """Raise ``OllamaCloudError`` carrying the upstream status for 4xx/5xx."""
    if response.status_code >= 400:
        raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)


def _chat_payload(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None,
    *,
    stream: bool,
) -> dict[str, Any]:
    """Build the JSON body for ``POST /api/chat``.

    ``think`` is included only when it is neither ``None`` nor the string
    ``"off"``; otherwise the key is omitted entirely.
    """
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": stream,
    }
    if think is not None and think != "off":
        payload["think"] = think
    return payload


async def list_models() -> list[dict[str, Any]]:
    """Return the models from ``GET /api/tags``, sorted by name.

    Each entry is sorted by its ``name`` key, falling back to ``model`` and
    then ``""``. Authentication is sent when configured but not required.

    Raises:
        OllamaCloudError: if the server answers with a 4xx/5xx status.
    """
    async with httpx.AsyncClient(timeout=_timeout()) as client:
        response = await client.get(_url("/api/tags"), headers=_headers(require_auth=False))
    _raise_for_status(response)
    payload = response.json()
    models = payload.get("models") or []
    return sorted(models, key=lambda item: item.get("name") or item.get("model") or "")


async def chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> dict[str, Any]:
    """Send a non-streaming chat request and return the decoded JSON reply.

    Args:
        model: Model name to run.
        messages: Chat messages, forwarded unchanged as the ``messages`` field.
        think: Optional ``think`` setting; ``None`` or ``"off"`` omits it.

    Raises:
        OllamaCloudError: if no API key is configured (503) or the server
            answers with a 4xx/5xx status.
    """
    payload = _chat_payload(model, messages, think, stream=False)
    async with httpx.AsyncClient(timeout=_timeout()) as client:
        response = await client.post(
            _url("/api/chat"),
            headers=_headers(require_auth=True),
            json=payload,
        )
    # Non-streamed responses are fully read, so the body is still available
    # after the client is closed.
    _raise_for_status(response)
    return response.json()


async def stream_chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Stream a chat request, yielding one decoded object per response line.

    Blank lines are skipped. A line that is not valid JSON is yielded as
    ``{"message": {"content": <line>}}`` so consumers still see the text.

    Raises:
        OllamaCloudError: if no API key is configured (503) or the server
            answers with a 4xx/5xx status. The error message is normalised
            the same way as the non-streaming calls.
    """
    payload = _chat_payload(model, messages, think, stream=True)
    async with httpx.AsyncClient(timeout=_timeout()) as client:
        async with client.stream(
            "POST",
            _url("/api/chat"),
            headers=_headers(require_auth=True),
            json=payload,
        ) as response:
            if response.status_code >= 400:
                # Streamed bodies are not loaded automatically; read it so
                # _normalise_error can inspect the JSON/text.
                await response.aread()
                _raise_for_status(response)
            async for line in response.aiter_lines():
                clean_line = line.strip()
                if not clean_line:
                    continue
                try:
                    yield json.loads(clean_line)
                except json.JSONDecodeError:
                    yield {"message": {"content": clean_line}}


async def web_search(query: str, max_results: int = 5) -> list[dict[str, Any]]:
    """Run ``POST /api/web_search`` and return its ``results`` list.

    ``max_results`` is clamped to the range 1..10 before being sent. Returns
    an empty list when the response has no (or empty) ``results``.

    Raises:
        OllamaCloudError: if no API key is configured (503) or the server
            answers with a 4xx/5xx status.
    """
    async with httpx.AsyncClient(timeout=_timeout()) as client:
        response = await client.post(
            _url("/api/web_search"),
            headers=_headers(require_auth=True),
            json={"query": query, "max_results": max(1, min(max_results, 10))},
        )
    _raise_for_status(response)
    payload = response.json()
    return payload.get("results") or []