
Comments (5)

seabull commented on April 28, 2024

Thank you @Cesar-Cardoso. I will try it out and post updates here.


Cesar-Cardoso commented on April 28, 2024

Hi @seabull! Could you provide a minimal reproducible example for this issue? From the stack trace, it's not obvious that the issue is directly related to the decoding logic. In particular, the error says that your data_ingestor object is missing the __dict__ attribute you're trying to reference, but it's hard to get much further without the rest of the code.
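To illustrate what that error means, here is a small standard-library sketch (not Ax code; ToyIngestor and describe are hypothetical names). An ordinary class instance exposes its attributes through __dict__, but a plain dict has no __dict__ attribute, so the same print() call fails if it receives a dict instead of the expected object.

```python
class ToyIngestor:
    """Hypothetical stand-in for the reporter's SimulateTreatmentData."""

    def __init__(self, num_minibatches_per_update: int = 3):
        self.num_minibatches_per_update = num_minibatches_per_update


def describe(data_ingestor: object) -> str:
    """Return the instance attributes, mirroring the failing print() call."""
    try:
        return repr(data_ingestor.__dict__)
    except AttributeError as err:
        # Built-in dicts have no per-instance __dict__, so we end up here.
        return f"AttributeError: {err}"


print(describe(ToyIngestor()))  # {'num_minibatches_per_update': 3}
print(describe({"num_minibatches_per_update": 3}))  # reports an AttributeError
```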


seabull commented on April 28, 2024

Thanks for the response.
Here is the __init__ that generates the error in the stack trace. I hope this provides some context:

class GeneralFactorialMetric(Metric):
    def __init__(
        self,
        data_ingestor: SimulateTreatmentData,
        name="barry",
        *args,
        **kwargs,
    ):
        super().__init__(name=name, *args, **kwargs)
        self.data_ingestor = data_ingestor
        self.experiment_values = {}
        print(f"{data_ingestor=}")
        print(f"{data_ingestor.__dict__=}")

As shown in the stack trace, the print(f"{data_ingestor.__dict__=}") statement fails because data_ingestor is a dict, not an instance of the SimulateTreatmentData class. If the decoder worked, data_ingestor.__dict__ would print successfully.
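One way a nested object can come back as a plain dict is sketched below using only the standard library json module; this is not Ax's JSON store, and ToyIngestor and encode are hypothetical names. The encoder tags the nested object with a "__type" field (similar to what treatment_data_to_dict does in the example further down), but a decoder that does not know how to rebuild that type simply returns the dict that json.loads produced.

```python
import json
from typing import Any, Dict


class ToyIngestor:
    """Hypothetical stand-in for SimulateTreatmentData."""

    def __init__(self, num_minibatches_per_update: int = 3):
        self.num_minibatches_per_update = num_minibatches_per_update


def encode(obj: Any) -> Dict[str, Any]:
    """Toy encoder: tag the object's attributes with its class name."""
    return {"__type": type(obj).__name__, **vars(obj)}


metric_args = {"name": "success_metric", "data_ingestor": ToyIngestor()}
# json.dumps calls encode() for objects it cannot serialize on its own.
payload = json.dumps(metric_args, default=encode)

# Decoding without any type-aware hook leaves the nested value as a dict.
naive = json.loads(payload)
print(type(naive["data_ingestor"]).__name__)  # dict
```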


seabull commented on April 28, 2024

I know the example below is a little long (sorry!), but it is the smallest one I could find that reproduces the error.

# %%
from typing import Any, Dict, Optional

from ax import (
    Arm,
    ChoiceParameter,
    Experiment,
    Metric,
    Objective,
    OptimizationConfig,
    ParameterType,
    Runner,
    SearchSpace,
)
from ax.storage.json_store.load import load_experiment
from ax.storage.json_store.registry import CORE_DECODER_REGISTRY, CORE_ENCODER_REGISTRY
from ax.storage.json_store.save import save_experiment
from ax.storage.registry_bundle import RegistryBundle
from ax.utils.common.serialization import SerializationMixin, serialize_init_args
from ax.utils.common.typeutils import checked_cast


class SimulateTreatmentData(SerializationMixin):
    def __init__(
        self,
        search_space_parameters: Optional[list[list[str]]] = None,
        num_minibatches_per_update: int = 3,
    ):
        self.search_space_parameters = search_space_parameters
        self.num_minibatches_per_update = num_minibatches_per_update

    @classmethod
    # pyre-fixme [2]: Parameter `obj` must have a type other than `Any`
    def serialize_init_args(cls, obj: Any) -> Dict[str, Any]:
        sim = checked_cast(SimulateTreatmentData, obj)
        properties = serialize_init_args(obj=sim)
        properties["search_space_parameters"] = sim.search_space_parameters
        properties["num_minibatches_per_update"] = sim.num_minibatches_per_update
        return properties


class GeneralFactorialMetric(Metric):
    def __init__(
        self,
        data_ingestor: SimulateTreatmentData,
        name="barry",
        *args,
        **kwargs,
    ):
        super().__init__(name=name, *args, **kwargs)
        self.data_ingestor = data_ingestor
        print(f"{data_ingestor.__dict__=}")  # fails after load_experiment: data_ingestor is a dict


my_treatments = ["treatment-0", "treatment-1", "treatment-2"]


class MyRunner(Runner):
    def run(self, trial):
        trial_metadata = {"name": str(trial.index)}
        return trial_metadata


search_space = SearchSpace(
    parameters=[
        ChoiceParameter(
            name="arm_name",  # "factor" + str(ifactor),
            parameter_type=ParameterType.STRING,
            values=my_treatments,
            is_ordered=False,
            sort_values=False,
        ),
    ]
)

data_simulator = SimulateTreatmentData(
    search_space_parameters=[par.values for par in search_space.parameters.values()],
)

exp = Experiment(
    name="my_factorial_closed_loop_experiment",
    search_space=search_space,
    optimization_config=OptimizationConfig(
        objective=Objective(
            metric=GeneralFactorialMetric(
                data_ingestor=data_simulator,
                name="success_metric",
            )
        )
    ),
    runner=MyRunner(),
)
arms = [
    Arm(parameters={"arm_name": treatment}, name=treatment)
    for treatment in my_treatments
]
trial = exp.new_batch_trial().add_arms_and_weights(arms=arms, weights=[0.4, 0.3, 0.3])


def treatment_data_to_dict(sim: SimulateTreatmentData) -> Dict[str, Any]:
    """Convert SimulateTreatmentData to a dictionary."""
    properties = sim.serialize_init_args(obj=sim)
    properties["__type"] = sim.__class__.__name__
    return properties


core_encoders = {
    **CORE_ENCODER_REGISTRY,
    SimulateTreatmentData: treatment_data_to_dict,
}


decoders = {
    **CORE_DECODER_REGISTRY,
    "SimulateTreatmentData": SimulateTreatmentData,
}

bundle = RegistryBundle(
    runner_clss={
        MyRunner: None,
    },
    metric_clss={
        GeneralFactorialMetric: None,
    },
    json_encoder_registry=core_encoders,
    json_decoder_registry=decoders,
)
import os

os.makedirs("temp", exist_ok=True)  # the output directory must exist before saving
filepath = "temp/my_experiment.json"
save_experiment(exp, filepath, encoder_registry=bundle.encoder_registry)
exp = load_experiment(filepath=filepath, decoder_registry=bundle.decoder_registry)


Cesar-Cardoso commented on April 28, 2024

Thanks for the examples! The issue occurs when decoding the GeneralFactorialMetric object: its data_ingestor property is loaded as a dict rather than a SimulateTreatmentData instance, which causes the error. One solution is to override deserialize_init_args for GeneralFactorialMetric as follows:

from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry


class GeneralFactorialMetric(Metric):
    def __init__(
        self,
        data_ingestor: SimulateTreatmentData,
        name="barry",
        *args,
        **kwargs,
    ):
        super().__init__(name=name, *args, **kwargs)
        self.data_ingestor = data_ingestor
        print(f"{data_ingestor.__dict__=}")  # prints the ingestor's attributes once decoding works

    @classmethod
    def deserialize_init_args(
        cls,
        args: Dict[str, Any],
        decoder_registry: Optional[TDecoderRegistry] = None,
        class_decoder_registry: Optional[TClassDecoderRegistry] = None,
    ) -> Dict[str, Any]:
        # data_ingestor arrives as an encoded dict, so rebuild the
        # SimulateTreatmentData instance from it explicitly.
        return {
            "data_ingestor": SimulateTreatmentData(
                **SimulateTreatmentData.deserialize_init_args(
                    args=args["data_ingestor"],
                    decoder_registry=decoder_registry,
                    class_decoder_registry=class_decoder_registry,
                )
            ),
            "name": args.get("name") or "barry",
        }

With this change, I was able to encode and decode an experiment using your code. Let me know if this solves your issue. Cheers!
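The same pattern can be tried without Ax. In this standard-library sketch, ToyIngestor and ToyMetric are hypothetical stand-ins for SimulateTreatmentData and GeneralFactorialMetric, and the decoder registries are left out. The outer class's deserialize_init_args rebuilds the nested object from its encoded dict (dropping the "__type" tag by hand) before __init__ ever sees it.

```python
import json
from typing import Any, Dict


class ToyIngestor:
    """Hypothetical stand-in for SimulateTreatmentData."""

    def __init__(self, num_minibatches_per_update: int = 3):
        self.num_minibatches_per_update = num_minibatches_per_update


class ToyMetric:
    """Hypothetical stand-in for GeneralFactorialMetric."""

    def __init__(self, data_ingestor: ToyIngestor, name: str = "barry"):
        self.data_ingestor = data_ingestor
        self.name = name

    @classmethod
    def deserialize_init_args(cls, args: Dict[str, Any]) -> Dict[str, Any]:
        # Rebuild the nested object from its encoded dict instead of
        # passing the dict straight through to __init__.
        nested = {k: v for k, v in args["data_ingestor"].items() if k != "__type"}
        return {
            "data_ingestor": ToyIngestor(**nested),
            "name": args.get("name") or "barry",
        }


payload = json.dumps(
    {
        "name": "success_metric",
        "data_ingestor": {"__type": "ToyIngestor", "num_minibatches_per_update": 3},
    }
)
metric = ToyMetric(**ToyMetric.deserialize_init_args(json.loads(payload)))
print(metric.data_ingestor.__dict__)  # {'num_minibatches_per_update': 3}
```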
