Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions panels/ModelCheckpointComparison/ModelCheckpointComparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from comet_ml import API, ui, APIExperiment, Artifact, start, ExistingExperiment
import json
import pandas as pd
import os
from ast import literal_eval

st.set_page_config(layout="wide")

api_key = st.sidebar.text_input(
"Comet API key:",
value=os.environ.get("COMET_REAL_API_KEY", ""),
type="password"
)
if not api_key:
st.info("Enter your Comet API key on left sidebar", icon="ℹ️")
st.stop()

os.environ["COMET_REAL_API_KEY"] = api_key

api = API(api_key)
experiments = api.get_panel_experiments()

if len(experiments) > 1:
api_experiment = ui.dropdown("Choose one:", experiments)
else:
api_experiment = experiments[0]

selected_metric = st.sidebar.selectbox('Select Metric to Compare:',[data['name'] for data in api_experiment.get_metrics_summary()])

data = api_experiment.get_asset_list('model-element')

def create_df(data):
#Convert metadata from string to json
for item in data:
item['metadata'] = json.loads(item['metadata'])

#Filter data to relevant columns
def filter_keys(data, keys):
return [{key: item[key] for key in keys} for item in data]

keys = ['fileName', 'dir', 'metadata']
filtered_data = filter_keys(data, keys)

# Create DataFrame
df = pd.json_normalize(filtered_data)

#drop the error_message and synced column if it exists (auto-logged by Comet)
if 'metadata.error_message' in df.columns:
df.drop('metadata.error_message', axis=1, inplace=True)
if 'metadata.synced' in df.columns:
df.drop('metadata.synced', axis=1, inplace=True)

df.rename(columns={'fileName': 'Asset-Name', 'dir': 'Model-Name'}, inplace=True)

df['Model-Name'] = df['Model-Name'].str.replace('models/', '')

# Swap the first and second columns
cols = df.columns.tolist()
cols[0], cols[1] = cols[1], cols[0]


#Set step and epoch as 3rd & 4th columns if they exist
if 'metadata.step' in cols:
if 'metadata.epoch' in cols:
cols.remove('metadata.step')
cols.insert(2, 'metadata.step')
cols.remove('metadata.epoch')
cols.insert(3, 'metadata.epoch')
else:
cols.remove('metadata.step')
cols.insert(2, 'metadata.step')
elif 'metadata.epoch' in cols:
cols.remove('metadata.epoch')
cols.insert(2, 'metadata.epoch')

df = df[cols]
return df

if len(data) > 0:
df = create_df(data)
if 'metadata.epoch' in df.columns and 'metadata.step' in df.columns:
step_or_epoch = st.sidebar.radio('Step or Epoch:', ["epoch", "step"], help="Map metric values to each checkpoint based on step or epoch")
elif 'metadata.epoch' in df.columns:
step_or_epoch = 'epoch'
elif 'metadata.step' in df.columns:
step_or_epoch = 'step'
else:
print("You must log 'step' or 'epoch' with your model metadata in order to use this panel")
exit()
metric_goal = st.sidebar.radio('Metric goal:', ["maximize", "minimize"])

#Map to metric value at each step/epoch
metric_data = api_experiment.get_metrics(selected_metric)
epoch_to_metric_value = {dp[step_or_epoch]: literal_eval(dp["metricValue"]) for dp in metric_data}
df[selected_metric] = df[f"metadata.{step_or_epoch}"].map(epoch_to_metric_value)

if metric_goal == "maximize":
df.sort_values(by=[selected_metric], inplace=True, ascending = False)
else:
df.sort_values(by=[selected_metric], inplace=True, ascending = True)


df.reset_index(drop=True, inplace=True)

st.dataframe(df, use_container_width=True)

# Model registration form
st.header("Create Model Artifact")

# Select a model to register
model_names = df['Model-Name'].unique().tolist()
selected_model = st.selectbox("Select a model:", model_names)

# Model name input
workspace = api.get_panel_workspace()
artifacts = api.get_artifact_list(workspace)["artifacts"]
artifact_names = [artifact["name"] for artifact in artifacts]
artifact_name = st.selectbox("Select an Artifact:", artifact_names, help = "A new version will be created within this artifact")
def create_artifact():
try:
api_experiment.download_model(selected_model, f"./{selected_model}")
# Filter and get the "Asset-Name" where "Model-Name" is "checkpoint_0"
asset_name = df.loc[df["Model-Name"] == selected_model, "Asset-Name"].values[0]
path = f'./{selected_model}/{asset_name}'
artifact = Artifact(name=artifact_name, metadata = {"model_checkpoint":path, f"{step_or_epoch}": df.loc[df["Model-Name"] == selected_model, f"metadata.{step_or_epoch}"].values[0]})
artifact.add(path)
#exp = ExistingExperiment(api_key = api_key, experiment_key=api_experiment.key)
exp = start(api_key = api_key, mode="get", experiment_key=api_experiment.key)
exp.log_artifact(artifact)
st.success("Artifact Created Successfully!")
exp.end()
except Exception as e:
ui.display_text(f"Error registering model: {str(e)}")

st.button("Create Model Artifact", on_click=create_artifact)

else:
ui.display('No models logged to this experiment')
67 changes: 67 additions & 0 deletions panels/ModelCheckpointComparison/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
### ModelCheckpointComparison

The `ModelCheckpointComparison` panel is used to compare performance of your model at each of the checkpoints logged. This is a useful tool to help determine which of your model checkpoints is best performing and should be promoted via the registry.

<table>
<tr>
<td>
<img src="https://raw.githubusercontent.com/comet-ml/comet-examples/refs/heads/master/panels/ModelCheckpointComparison/model-comparison-panel.png"
style="max-width: 300px; max-height: 300px;">
</img>
</td>
</tr>
</table>

First, run your experiment, including logging the model checkpoints and metrics at each step/epoch in your training loop. Each model checkpoint should log the step or epoch to the metadata field, and be uniquely named based on step/epoch, so that the panel can later match each checkpoint to performance at that step/epoch.

```python
#Log the model checkpoint directly to Comet at each epoch
for i in range(10):
experiment.log_model(f'checkpoint_{i}', '/path/to/your/model.pkl', metadata = {'epoch': i})
experiment.log_metric('metric1', i, epoch=i)
experiment.log_metric('metric2', 50-i, epoch=i)


#Or log a pointer to the model checkpoint at each epoch
for i in range(10):
experiment.log_remote_model(f'checkpoint_{i}', '/path/to/your/model.pkl', metadata = {'epoch': i})
experiment.log_metric('metric1', i, epoch=i)
experiment.log_metric('metric2', 50-i, epoch=i)
```

Finally click on "Select Experiment with log:" in this panel.

#### Example

This example logs some dummy metric + model checkpoint data to Comet so that you can test out the panel.

```python
import comet_ml

#Start Comet experiment
comet_ml.login()
experiment = comet_ml.start(project_name="tf-profiler")

for i in range(10):
experiment.log_remote_model(f'checkpoint_{i}', '/path/to/your/model.pkl', metadata = {'epoch': i})
experiment.log_metric('metric1', i, epoch=i)
experiment.log_metric('metric2', 50-i, epoch=i)

experiment.end()
```

#### Python Panel

To include this panel from the github repo, use this code in a Custom Python Panel:

```
%include https://raw.githubusercontent.com/comet-ml/comet-examples/refs/heads/master/panels/ModelCheckpointComparison/ModelCheckpointComparison.py
```

Or, you can simply [copy the code](https://raw.githubusercontent.com/comet-ml/comet-examples/refs/heads/master/panels/ModelCheckpointComparison/ModelCheckpointComparison.py) into a custom Python Panel.

#### How it works

The Python panel will retrieve a list of your model checkpoints, then use the epoch values logged to the checkpoint metadata to fine the value of the specific metric at that epoch.


Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading