Skip to content

Commit

Permalink
Support multiple tags in each endpoint (#687)
Browse files Browse the repository at this point in the history
Currently when an endpoint has multiple tags, the first tag is used and
everything else is ignored.

This PR modifies it so endpoints with multiple tags are added to each of
the tags.

Yes, this results in repeated code 😅, but works beautifully
and functions can now be found anywhere we expect them to be.

---------

Co-authored-by: Dylan Anthony <[email protected]>
Co-authored-by: Dylan Anthony <[email protected]>
  • Loading branch information
3 people authored Dec 24, 2024
1 parent 7225f0e commit 88b3be1
Show file tree
Hide file tree
Showing 14 changed files with 175 additions and 120 deletions.
8 changes: 8 additions & 0 deletions .changeset/add_generate_all_tags_config_option.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
default: minor
---

# Add `generate_all_tags` config option

You can now, optionally, generate **duplicate** endpoint functions/modules using _every_ tag for an endpoint,
not just the first one, by setting `generate_all_tags: true` in your configuration file.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ literal_enums: true

This is especially useful if enum values, when transformed to their Python names, end up conflicting due to case sensitivity or special symbols.

### generate_all_tags

`openapi-python-client` generates module names within the `api` module based on the OpenAPI `tags` of each endpoint.
By default, only the _first_ tag is generated. If you want to generate **duplicate** endpoint functions using _every_ tag
listed, you can enable this option:

```yaml
generate_all_tags: true
```

### project_name_override and package_name_override

Used to change the name of generated client library project/package. If the project name is changed but an override for the package name
Expand Down
14 changes: 14 additions & 0 deletions end_to_end_tests/__snapshots__/test_end_to_end.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# serializer version: 1
# name: test_documents_with_errors[bad-status-code]
'''
Generating /test-documents-with-errors
Warning(s) encountered while generating. Client was generated, but some pieces may be missing

WARNING parsing GET / within default.

Invalid response status code abcdef (not a valid HTTP status code), response will be omitted from generated client


If you believe this was a mistake or this tool is missing a feature you need, please open an issue at https://github.com/openapi-generators/openapi-python-client/issues/new/choose

'''
# ---
# name: test_documents_with_errors[circular-body-ref]
'''
Generating /test-documents-with-errors
Expand Down
4 changes: 1 addition & 3 deletions end_to_end_tests/baseline_openapi_3.0.json
Original file line number Diff line number Diff line change
Expand Up @@ -1149,9 +1149,7 @@
},
"/tag_with_number": {
"get": {
"tags": [
"1"
],
"tags": ["1", "2"],
"responses": {
"200": {
"description": "Success"
Expand Down
4 changes: 1 addition & 3 deletions end_to_end_tests/baseline_openapi_3.1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1141,9 +1141,7 @@ info:
},
"/tag_with_number": {
"get": {
"tags": [
"1"
],
"tags": ["1", "2"],
"responses": {
"200": {
"description": "Success"
Expand Down
1 change: 1 addition & 0 deletions end_to_end_tests/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ class_overrides:
field_prefix: attr_
content_type_overrides:
openapi/python/client: application/json
generate_all_tags: true
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .parameters import ParametersEndpoints
from .responses import ResponsesEndpoints
from .tag1 import Tag1Endpoints
from .tag2 import Tag2Endpoints
from .tests import TestsEndpoints
from .true_ import True_Endpoints

Expand Down Expand Up @@ -48,6 +49,10 @@ def parameters(cls) -> type[ParametersEndpoints]:
def tag1(cls) -> type[Tag1Endpoints]:
return Tag1Endpoints

@classmethod
def tag2(cls) -> type[Tag2Endpoints]:
return Tag2Endpoints

@classmethod
def location(cls) -> type[LocationEndpoints]:
return LocationEndpoints
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Contains methods for accessing the API Endpoints"""

import types

from . import get_tag_with_number


class Tag2Endpoints:
@classmethod
def get_tag_with_number(cls) -> types.ModuleType:
return get_tag_with_number
14 changes: 14 additions & 0 deletions end_to_end_tests/documents_with_errors/bad-status-code.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
openapi: "3.1.0"
info:
title: "There's something wrong with me"
version: "0.1.0"
paths:
"/":
get:
responses:
"abcdef":
description: "Successful Response"
content:
"application/json":
schema:
const: "Why have a fixed response? I dunno"
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from http import HTTPStatus
from typing import Any, Optional, Union

import httpx

from ... import errors
from ...client import AuthenticatedClient, Client
from ...types import Response


def _get_kwargs() -> dict[str, Any]:
_kwargs: dict[str, Any] = {
"method": "get",
"url": "/tag_with_number",
}

return _kwargs


def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[Any]:
if response.status_code == 200:
return None
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
else:
return None


def _build_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Response[Any]:
return Response(
status_code=HTTPStatus(response.status_code),
content=response.content,
headers=response.headers,
parsed=_parse_response(client=client, response=response),
)


def sync_detailed(
*,
client: Union[AuthenticatedClient, Client],
) -> Response[Any]:
"""
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[Any]
"""

kwargs = _get_kwargs()

response = client.get_httpx_client().request(
**kwargs,
)

return _build_response(client=client, response=response)


async def asyncio_detailed(
*,
client: Union[AuthenticatedClient, Client],
) -> Response[Any]:
"""
Raises:
errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
httpx.TimeoutException: If the request takes longer than Client.timeout.
Returns:
Response[Any]
"""

kwargs = _get_kwargs()

response = await client.get_async_httpx_client().request(**kwargs)

return _build_response(client=client, response=response)
3 changes: 3 additions & 0 deletions openapi_python_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class ConfigFile(BaseModel):
use_path_prefixes_for_title_model_names: bool = True
post_hooks: Optional[list[str]] = None
field_prefix: str = "field_"
generate_all_tags: bool = False
http_timeout: int = 5
literal_enums: bool = False

Expand Down Expand Up @@ -70,6 +71,7 @@ class Config:
use_path_prefixes_for_title_model_names: bool
post_hooks: list[str]
field_prefix: str
generate_all_tags: bool
http_timeout: int
literal_enums: bool
document_source: Union[Path, str]
Expand Down Expand Up @@ -110,6 +112,7 @@ def from_sources(
use_path_prefixes_for_title_model_names=config_file.use_path_prefixes_for_title_model_names,
post_hooks=post_hooks,
field_prefix=config_file.field_prefix,
generate_all_tags=config_file.generate_all_tags,
http_timeout=config_file.http_timeout,
literal_enums=config_file.literal_enums,
document_source=document_source,
Expand Down
32 changes: 19 additions & 13 deletions openapi_python_client/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,18 @@ def from_data(
operation: Optional[oai.Operation] = getattr(path_data, method)
if operation is None:
continue
tag = utils.PythonIdentifier(value=(operation.tags or ["default"])[0], prefix="tag")
collection = endpoints_by_tag.setdefault(tag, EndpointCollection(tag=tag))

tags = [utils.PythonIdentifier(value=tag, prefix="tag") for tag in operation.tags or ["default"]]
if not config.generate_all_tags:
tags = tags[:1]

collections = [endpoints_by_tag.setdefault(tag, EndpointCollection(tag=tag)) for tag in tags]

endpoint, schemas, parameters = Endpoint.from_data(
data=operation,
path=path,
method=method,
tag=tag,
tags=tags,
schemas=schemas,
parameters=parameters,
request_bodies=request_bodies,
Expand All @@ -87,15 +92,16 @@ def from_data(
if not isinstance(endpoint, ParseError):
endpoint = Endpoint.sort_parameters(endpoint=endpoint)
if isinstance(endpoint, ParseError):
endpoint.header = (
f"WARNING parsing {method.upper()} {path} within {tag}. Endpoint will not be generated."
)
collection.parse_errors.append(endpoint)
endpoint.header = f"WARNING parsing {method.upper()} {path} within {'/'.join(tags)}. Endpoint will not be generated."
for collection in collections:
collection.parse_errors.append(endpoint)
continue
for error in endpoint.errors:
error.header = f"WARNING parsing {method.upper()} {path} within {tag}."
collection.parse_errors.append(error)
collection.endpoints.append(endpoint)
error.header = f"WARNING parsing {method.upper()} {path} within {'/'.join(tags)}."
for collection in collections:
collection.parse_errors.append(error)
for collection in collections:
collection.endpoints.append(endpoint)

return endpoints_by_tag, schemas, parameters

Expand Down Expand Up @@ -132,7 +138,7 @@ class Endpoint:
description: Optional[str]
name: str
requires_security: bool
tag: str
tags: list[PythonIdentifier]
summary: Optional[str] = ""
relative_imports: set[str] = field(default_factory=set)
query_parameters: list[Property] = field(default_factory=list)
Expand Down Expand Up @@ -393,7 +399,7 @@ def from_data(
data: oai.Operation,
path: str,
method: str,
tag: str,
tags: list[PythonIdentifier],
schemas: Schemas,
parameters: Parameters,
request_bodies: dict[str, Union[oai.RequestBody, oai.Reference]],
Expand All @@ -413,7 +419,7 @@ def from_data(
description=utils.remove_string_escapes(data.description) if data.description else "",
name=name,
requires_security=bool(data.security),
tag=tag,
tags=tags,
)

result, schemas, parameters = Endpoint.add_parameters(
Expand Down
Loading

0 comments on commit 88b3be1

Please sign in to comment.