Skip to content

Commit

Permalink
Merge pull request #71 from x-tabdeveloping/color_palette
Browse files Browse the repository at this point in the history
Color palette
  • Loading branch information
x-tabdeveloping authored Nov 5, 2024
2 parents 1629496 + ecd4ef2 commit 1da38f1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ line-length=79

[tool.poetry]
name = "turftopic"
version = "0.8.0"
version = "0.8.1"
description = "Topic modeling with contextual representations from sentence transformers."
authors = ["Márton Kardos <[email protected]>"]
license = "MIT"
Expand Down
36 changes: 33 additions & 3 deletions turftopic/dynamic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Optional, Union
from typing import Any, Iterable, Optional, Union

import numpy as np
from rich.console import Console
Expand Down Expand Up @@ -273,7 +274,12 @@ def export_topics_over_time(
table = self._topics_over_time(top_k, show_scores, date_format)
return export_table(table, format=format)

def plot_topics_over_time(self, top_k: int = 6):
def plot_topics_over_time(
self,
top_k: int = 6,
color_discrete_sequence: Optional[Iterable[str]] = None,
color_discrete_map: Optional[dict[str, str]] = None,
):
"""Displays topics over time in the fitted dynamic model on a dynamic HTML figure.
> You will need to `pip install plotly` to use this method.
Expand All @@ -282,6 +288,18 @@ def plot_topics_over_time(self, top_k: int = 6):
----------
top_k: int, default 6
Number of top words per topic to display on the figure.
color_discrete_sequence: Iterable[str], default None
Color palette to use in the plot.
Example:
```python
import plotly.express as px
model.plot_topics_over_time(color_discrete_sequence=px.colors.qualitative.Light24)
```
color_discrete_map: dict[str, str], default None
Topic names mapped to the colors that should
be associated with them.
Returns
-------
Expand All @@ -296,14 +314,25 @@ def plot_topics_over_time(self, top_k: int = 6):
raise ModuleNotFoundError(
"Please install plotly if you intend to use plots in Turftopic."
) from e
if color_discrete_sequence is not None:
topic_colors = itertools.cycle(color_discrete_sequence)
elif color_discrete_map is not None:
topic_colors = [
color_discrete_map[topic_name]
for topic_name in self.topic_names
]
else:
topic_colors = px.colors.qualitative.Dark24
fig = go.Figure()
vocab = self.get_vocab()
n_topics = self.temporal_components_.shape[1]
try:
topic_names = self.topic_names
except AttributeError:
topic_names = [f"Topic {i}" for i in range(n_topics)]
for i_topic, topic_imp_t in enumerate(self.temporal_importance_.T):
for trace_color, (i_topic, topic_imp_t) in zip(
topic_colors, enumerate(self.temporal_importance_.T)
):
component_over_time = self.temporal_components_[:, i_topic, :]
name_over_time = []
for component in component_over_time:
Expand All @@ -326,6 +355,7 @@ def plot_topics_over_time(self, top_k: int = 6):
marker=dict(
line=dict(width=2, color="black"),
size=14,
color=trace_color,
),
line=dict(width=3),
)
Expand Down

0 comments on commit 1da38f1

Please sign in to comment.